<div class="align-center">
<a href="https://oumi.ai/"><img src="https://oumi.ai/docs/en/latest/_static/logo/header_logo.png" height="200"></a>

[![Documentation](https://img.shields.io/badge/Documentation-latest-blue.svg)](https://oumi.ai/docs/en/latest/index.html)
[![Discord](https://img.shields.io/discord/1286348126797430814?label=Discord)](https://discord.gg/oumi)
[![GitHub Repo stars](https://img.shields.io/github/stars/oumi-ai/oumi)](https://github.com/oumi-ai/oumi)
</div>

👋 Welcome to Open Universal Machine Intelligence (Oumi)!

🚀 Oumi is a fully open-source platform that streamlines the entire lifecycle of foundation models, from [data preparation](https://oumi.ai/docs/en/latest/resources/datasets/datasets.html) and [training](https://oumi.ai/docs/en/latest/user_guides/train/train.html) to [evaluation](https://oumi.ai/docs/en/latest/user_guides/evaluate/evaluate.html) and [deployment](https://oumi.ai/docs/en/latest/user_guides/launch/launch.html). Whether you're developing on a laptop, launching large-scale experiments on a cluster, or deploying models in production, Oumi provides the tools and workflows you need.

🤝 Make sure to join our [Discord community](https://discord.gg/oumi) to get help, share your experiences, and contribute to the project! If you are interested in joining one of the community's open-science efforts, check out our [open collaboration](https://oumi.ai/community) page.

⭐ If you like Oumi and you would like to support it, please give it a star on [GitHub](https://github.com/oumi-ai/oumi).

# Custom Judge

Our platform offers the [Oumi judge](https://github.com/oumi-ai/oumi/blob/main/notebooks/Oumi%20-%20Oumi%20Judge.ipynb), which can help you filter unhelpful, dishonest, or unsafe examples out of your training dataset. Alternatively, our platform enables you to define a custom judge that labels and filters user-defined attributes, using your own templates and custom parsing logic for judgements. This notebook demonstrates how to build a custom judge.

### Problem Statement

Suppose, for the sake of this tutorial, that you live in a country where bananas are strictly illegal and talking about them is a crime. You want to fine-tune an LLM on a training dataset, but you want to ensure that the resulting model will never mention bananas in its responses. You decide to build a custom judge that labels all the examples that mention bananas (either directly or indirectly), so that you can filter them out of your data.

### Attribute Definition

In order to define a custom judge, we first need to define the `attribute` it will judge. This requires the following definitions:
1. Message: Structure of the `Message` that will be judged.
2. System Instruction: Instruction to request that the model operates as a judge.
3. Judgement: Structure of the `Message` that includes the judgement and how it can be parsed to extract the label.
4. (Optional) Few-shot examples. 

##### 1. Message (to be judged)

Let's name the message to be judged `BananaMessage` and ensure it inherits from `TemplatedMessage`. This is a message intended to train a model, so it must consist of a request (by a user) and a response (by the model, i.e., the AI assistant). We could alternatively include only the AI assistant's response, but combining it with the request provides more context to the judge.
- The required `role` parameter must be set to `USER`. This is the default setting for the messages to be judged.
- The parameters `request` and `response` are user-defined and optional. We can have as many user-defined parameters as we want and name them any way we want, in order to capture all information that must be embedded in each message. 
- The required `template` parameter (a [Jinja2 Template](https://pypi.org/project/Jinja2/)) describes how the user-defined parameters (`request` and `response` in our specific example) will be combined to produce the message to be judged. 

In [1]:
from oumi.core.types.conversation import Role, TemplatedMessage


class BananaMessage(TemplatedMessage):
    role: Role = Role.USER
    request: str
    response: str
    template: str = "<request>{{request}}</request><response>{{response}}</response>"

We can generate a `BananaMessage` and inspect it as follows:

In [2]:
my_message = BananaMessage(request="What is your favorite fruit?", response="Banana!")
print(my_message.content)

<request>What is your favorite fruit?</request><response>Banana!</response>
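To make the role of the `template` concrete, here is a minimal, stdlib-only sketch of the substitution step. The `render_template` helper is hypothetical and is a deliberately simplified stand-in for Jinja2: it only understands bare `{{ name }}` placeholders, not filters or control flow. We assume that, for a template this simple, the rendering Oumi performs yields the same `content`, but the actual rendering is done by the library.

```python
import re

# Matches bare placeholders such as {{request}} or {{ response }}.
_PLACEHOLDER = re.compile(r"\{\{\s*(\w+)\s*\}\}")


def render_template(template: str, **fields: str) -> str:
    """Replace each {{ name }} placeholder with the keyword argument of the same name."""
    # A function replacement inserts the returned text literally (no escape processing).
    return _PLACEHOLDER.sub(lambda match: fields[match.group(1)], template)


template = "<request>{{request}}</request><response>{{response}}</response>"
content = render_template(
    template, request="What is your favorite fruit?", response="Banana!"
)
print(content)
```

A missing field raises a `KeyError`, which surfaces template/field mismatches early instead of silently rendering an empty string.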


##### 2. System Instruction

Next, we must define a message, to be used as a System Instruction, which will instruct the model to operate as a judge. 

In [3]:
judge_SI = (
    "Act as a judge and try to identify the illegal word 'banana' in the text. "
    "Specifically, if bananas are mentioned in the response, either directly "
    "or indirectly (for example: by describing how they look rather than "
    "explicitly mentioning the word 'bananas'), your answer should be 'True'. "
    "If there is absolutely no mention of bananas, your answer should be 'False'. "
    "Only respond with 'True' or 'False', no other information is necessary."
)

##### 3. Judgement

Next, we must define how the judge's answer should be structured and how to parse it to extract the label. For this, we define the `BananaJudgement` class, which must inherit from `BaseJudgeOutput`. This class is also a `TemplatedMessage` and already includes the field `raw_judgement`, which is where the judge's response will be written. Similarly to `BananaMessage`, we could define a template in XML format (or whatever format makes sense for our use case) to easily extract the `label` from the judge's response (`raw_judgement`). For simplicity, however, we use a trivial template that adds no special characters or words around `raw_judgement`, which means the judge is expected to respond directly with the label. This is consistent with our `judge_SI` (above), where we request that the model respond only with `True` or `False`.

Note: If, for instance, we also wanted to include an explanation, we would need to define an additional `explanation` field, combine it with `raw_judgement` in the `template`, and finally implement `_transform_model_output` in our custom judge. `_transform_model_output` describes how to extract the label (and any other custom fields, such as our `explanation`) from the `raw_judgement`. See [Oumi judge](https://github.com/oumi-ai/oumi/blob/main/src/oumi/judges/oumi_judge.py) for reference.

In [4]:
from oumi.judges.base_judge import BaseJudgeOutput


class BananaJudgement(BaseJudgeOutput):
    role: Role = Role.ASSISTANT
    template: str = "{{ raw_judgement }}"
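For the label-plus-explanation variant described in the note above, the parsing that `_transform_model_output` would need boils down to extracting tagged fields from `raw_judgement`. The sketch below is plain Python, not Oumi code; the `<label>`/`<explanation>` tag names and both helper functions are hypothetical choices made for illustration.

```python
import re
from typing import Dict, Optional


def extract_field(raw_judgement: str, tag: str) -> Optional[str]:
    """Return the stripped text between <tag> and </tag>, or None if the tag is absent."""
    pattern = rf"<{re.escape(tag)}>(.*?)</{re.escape(tag)}>"
    # DOTALL lets explanations span multiple lines.
    match = re.search(pattern, raw_judgement, flags=re.DOTALL)
    return match.group(1).strip() if match else None


def parse_judgement(raw_judgement: str) -> Dict[str, Optional[str]]:
    """Split a raw judgement in the hypothetical tagged format into its fields."""
    return {
        "label": extract_field(raw_judgement, "label"),
        "explanation": extract_field(raw_judgement, "explanation"),
    }


raw = "<label>True</label><explanation>Describes a curved, yellow fruit.</explanation>"
print(parse_judgement(raw))
```

Returning `None` for a missing tag, rather than raising, leaves it to the caller to decide whether a malformed judgement should be retried or discarded.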

##### 4. Few-shot examples

Finally, let's (optionally) define few-shot examples, which help our custom judge understand its task, as well as the format it should respond with.

In [5]:
few_shot_examples = [
    BananaMessage(
        request="What does your favorite fruit look like?",
        response="It's curved and yellow with a thick skin and soft sweet flesh",
    ),
    BananaJudgement(
        raw_judgement="True",
    ),
]
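Few-shot examples typically reach the model as prior conversation turns. We have not verified how Oumi assembles the final prompt internally, so the following is only a plausible sketch under that assumption: the hypothetical `build_judge_prompt` helper places each example message and its judgement as a user/assistant pair between the system instruction (abbreviated here) and the message being judged.

```python
from typing import Dict, List, Sequence, Tuple


def build_judge_prompt(
    system_prompt: str,
    examples: Sequence[Tuple[str, str]],
    content: str,
) -> List[Dict[str, str]]:
    """Assemble a chat-style prompt: instruction, few-shot pairs, then the input to judge."""
    messages = [{"role": "system", "content": system_prompt}]
    for example_input, example_judgement in examples:
        # Each few-shot pair becomes a user turn followed by the expected assistant reply.
        messages.append({"role": "user", "content": example_input})
        messages.append({"role": "assistant", "content": example_judgement})
    # The message to be judged comes last, so the model's next turn is the judgement.
    messages.append({"role": "user", "content": content})
    return messages


prompt = build_judge_prompt(
    "Act as a judge and try to identify the illegal word 'banana' in the text.",
    [
        (
            "<request>What does your favorite fruit look like?</request>"
            "<response>It's curved and yellow with a thick skin and soft sweet flesh</response>",
            "True",
        )
    ],
    "<request>Do you like hiking?</request><response>I love all sorts of sports.</response>",
)
for message in prompt:
    print(message["role"], "->", message["content"])
```

Framing the examples as completed exchanges shows the model both what to look for and the exact one-word answer format it should produce.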

##### Putting everything together

We can now define an attribute (`JudgeAttribute`) to describe to our custom judge how to judge, as follows. 

Note that we have set the `value_type` (i.e., the type of the `label` extracted from the judge's response) as `BOOL`. Other options are `CATEGORICAL` and `LIKERT_5`. For more details see our [judge config](https://github.com/oumi-ai/oumi/blob/main/src/oumi/core/configs/judge_config.py).

In [6]:
from oumi.core.configs.judge_config import JudgeAttribute, JudgeAttributeValueType

banana_attribute = JudgeAttribute(
    name="banana_attribute",
    system_prompt=judge_SI,
    examples=few_shot_examples,
    value_type=JudgeAttributeValueType.BOOL,
)
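The `value_type` determines what a valid label looks like. As a rough illustration, here is how raw judge text could be converted for each of the three options. The `ValueType` enum and `parse_label` function are hypothetical stand-ins, not Oumi's `JudgeAttributeValueType` or its parsing code, and we assume a `LIKERT_5` label is an integer from 1 to 5.

```python
from enum import Enum, auto
from typing import Optional, Sequence, Union


class ValueType(Enum):
    """Hypothetical mirror of the three value types mentioned above."""

    BOOL = auto()
    CATEGORICAL = auto()
    LIKERT_5 = auto()


Label = Union[bool, int, str]


def parse_label(
    raw: str, value_type: ValueType, categories: Sequence[str] = ()
) -> Optional[Label]:
    """Convert raw judge text into a typed label, or None if it cannot be parsed."""
    text = raw.strip()
    if value_type is ValueType.BOOL:
        lowered = text.lower()
        if lowered in ("true", "false"):
            return lowered == "true"
        return None
    if value_type is ValueType.LIKERT_5:
        # Assumed 1-5 integer scale.
        if text.isdigit() and 1 <= int(text) <= 5:
            return int(text)
        return None
    if value_type is ValueType.CATEGORICAL:
        return text if text in categories else None
    raise ValueError(f"Unsupported value type: {value_type}")


print(parse_label(" True\n", ValueType.BOOL))
print(parse_label("4", ValueType.LIKERT_5))
```

Returning `None` for unparseable text is a design choice: it lets the caller decide whether a malformed judgement should count as flagged, safe, or be re-queried.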

### Judge Config

After defining the attribute(s) that our judge will label (`banana_attribute`), we also need to define the underlying model that we will use for inference. Specifically, we need to provide the `ModelParams` and `GenerationParams`. For more details on these, please refer to our [Oumi judge](https://github.com/oumi-ai/oumi/blob/main/notebooks/Oumi%20-%20Oumi%20Judge.ipynb) notebook. Once we define these parameters, we have a `JudgeConfig` that fully describes the configuration of our judge.

In [7]:
from oumi.core.configs import GenerationParams, JudgeConfig, ModelParams

banana_config = JudgeConfig(
    attributes={"banana_attribute": banana_attribute},
    model=ModelParams(model_name="HuggingFaceTB/SmolLM2-135M-Instruct"),
    generation=GenerationParams(max_new_tokens=1024),
)

### Judge Definition

The final step is to define our custom judge (`BananaJudge`), which should inherit from `BaseJudge`. Our subclass implements three methods that describe how the judge converts its input data and parses its judgement:

- `_transform_conversation_input`: Defines how to convert our input into a `BananaMessage` if it is provided as an `oumi.core.types.conversation.Conversation`. This method is not necessary in our current implementation, since our inputs are already `TemplatedMessage` objects.
- `_transform_dict_input`: Defines how to convert our input into a `BananaMessage` if it is provided as a `dict`. Again, this method is not necessary in our current implementation, since our inputs are already `TemplatedMessage` objects.
- `_transform_model_output`: Defines how to parse `raw_judgement` into multiple user-defined fields (which can help the user better understand the label) and how to extract the `label`. Since our system instruction asks the judge to respond directly with the `label`, this method is trivial: all it needs to do is pass the `model_output` through to `raw_judgement`.
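If the training data were stored as conversations rather than `BananaMessage` objects, `_transform_conversation_input` would need to pull out the request and response to fill the template. The function below sketches only that extraction step in plain Python. It represents a conversation as a list of `{"role", "content"}` dicts, which is an assumption for illustration rather than Oumi's `Conversation` type, and the helper name is hypothetical.

```python
from typing import Dict, List, Tuple


def conversation_to_request_response(turns: List[Dict[str, str]]) -> Tuple[str, str]:
    """Return the last user turn and the assistant reply that immediately follows it."""
    # Walk backwards so multi-turn conversations yield their most recent exchange.
    for i in range(len(turns) - 2, -1, -1):
        if turns[i]["role"] == "user" and turns[i + 1]["role"] == "assistant":
            return turns[i]["content"], turns[i + 1]["content"]
    raise ValueError("No user turn followed by an assistant turn was found.")


turns = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Do you like apples?"},
    {"role": "assistant", "content": "Not as much as I like bananas and kiwis."},
]
request, response = conversation_to_request_response(turns)
print(request, "|", response)
```

Judging only the final exchange is a simplification; for multi-turn data you might instead want to judge every assistant turn, since a banana could appear anywhere in the conversation.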

Once `BananaJudge` is defined, we can instantiate it using our `JudgeConfig` (`banana_config`).

In [8]:
from oumi.judges.base_judge import BaseJudge


class BananaJudge(BaseJudge):
    def _transform_conversation_input(self, conversation):
        raise NotImplementedError

    def _transform_dict_input(self, raw_input):
        raise NotImplementedError

    def _transform_model_output(self, model_output):
        return BananaJudgement(raw_judgement=model_output)


my_banana_judge = BananaJudge(config=banana_config)

[2025-01-16 15:48:03,229][oumi][rank0][pid:79301][MainThread][INFO]][models.py:174] Building model using device_map: auto (DeviceRankInfo(world_size=1, rank=0, local_world_size=1, local_rank=0))...
[2025-01-16 15:48:03,335][oumi][rank0][pid:79301][MainThread][INFO]][models.py:244] Using model class: <class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'> to instantiate model.



### Judge Inference

We can now call our `BananaJudge`'s `judge` method to judge a list of messages (i.e., our training dataset):

In [9]:
training_dataset = [
    BananaMessage(
        request="Do you like apples?",
        response="Not as much as I like bananas and kiwis.",
    ),
    BananaMessage(
        request="What did you eat earlier?",
        response="I ate a yellow tropical fruit, similar to plantain",
    ),
    BananaMessage(
        request="Do you like hiking?",
        response="I love all sorts of sports.",
    ),
]

judge_output = my_banana_judge.judge(training_dataset)

Generating Model Responses: 100%|██████████| 3/3 [00:02<00:00,  1.25it/s]


### Inspecting the judgments

Let's inspect the judge's output. The first two conversations mention bananas (explicitly and implicitly, respectively), so they should be filtered out. Our judge indicates this by setting the `label` to `True`. The third conversation is about sports, so our judge sets the `label` to `False`, indicating that it is safe to train on it.

In [10]:
def inspect_judge_output(judge_output, training_dataset):
    """Prints the judge output in a human-readable format."""
    for conversation, judgement in zip(training_dataset, judge_output):
        print("Input:", conversation.content)
        print("Judgement:", judgement["banana_attribute"]["label"])


inspect_judge_output(judge_output, training_dataset)

Input: <request>Do you like apples?</request><response>Not as much as I like bananas and kiwis.</response>
Judgement: True
Input: <request>What did you eat earlier?</request><response>I ate a yellow tropical fruit, similar to plantain</response>
Judgement: True
Input: <request>Do you like hiking?</request><response>I love all sorts of sports.</response>
Judgement: False
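The remaining step of the original goal is to drop the flagged examples. Below is a minimal sketch, assuming each judgement can be indexed as `judgement["banana_attribute"]["label"]` (as in `inspect_judge_output` above); plain lists and dicts stand in for the real objects, and `filter_flagged` is a hypothetical helper. It keeps an example only when the label is exactly `False`, so any judgement that did not parse to a boolean is conservatively discarded.

```python
from typing import Any, Dict, List, Sequence, TypeVar

T = TypeVar("T")


def filter_flagged(
    dataset: Sequence[T],
    judge_output: Sequence[Dict[str, Dict[str, Any]]],
    attribute: str = "banana_attribute",
) -> List[T]:
    """Keep only the examples whose label for `attribute` is exactly False."""
    if len(dataset) != len(judge_output):
        raise ValueError("dataset and judge_output must have the same length")
    return [
        example
        for example, judgement in zip(dataset, judge_output)
        if judgement[attribute]["label"] is False
    ]


# Plain stand-ins mirroring the three judgements printed above.
examples = ["apples/bananas", "yellow tropical fruit", "hiking"]
judgements = [
    {"banana_attribute": {"label": True}},
    {"banana_attribute": {"label": True}},
    {"banana_attribute": {"label": False}},
]
print(filter_flagged(examples, judgements))
```

The length check guards against silently misaligned pairs, since `zip` would otherwise stop at the shorter sequence.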


### Conclusion

Thanks to your judge's hard work, your trained model's conversations will be banana-free! Remember that the judge is only as powerful as its underlying model (SmolLM2-135M-Instruct in this tutorial), so if the quality of the judgements seems unsatisfactory, consider experimenting with larger models.