# Copyright 2025 - Oumi## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.importcsvimportpandasaspdPARQUET_EXTENSION=".parquet"
def save_infer_prob(output_filepath: str, probabilities: list[list[list[float]]]):
    """Write batched probabilities to a parquet file.

    Args:
        output_filepath: Destination path WITHOUT the file extension;
            `PARQUET_EXTENSION` is appended to it before writing.
        probabilities: Nested lists of probabilities. They are handed to
            `pd.DataFrame` as-is, so each outer list becomes one row.
    """
    parquet_path = f"{output_filepath}{PARQUET_EXTENSION}"
    pd.DataFrame(probabilities).to_parquet(parquet_path)
def load_infer_prob(input_filepath: str) -> list[list[list[float]]]:
    """Retrieve batched probabilities from a parquet file.

    Args:
        input_filepath: Path of the file to read WITHOUT the file extension;
            `PARQUET_EXTENSION` is appended to it before reading.

    Returns:
        Nested lists: one list per row of the file, holding one list per cell
        of that row.

    Raises:
        ValueError: If the entries do not all hold the same number of
            probabilities as the first entry read.
    """
    parquet_path = f"{input_filepath}{PARQUET_EXTENSION}"
    expected_count = None

    def to_list(probs):
        """Convert one entry to a `list` and check its length against the first."""
        nonlocal expected_count
        probs_list = list(probs)
        # Compare with `None` explicitly: a first entry of length 0 is falsy, so
        # an `or`-based default would silently adopt the length of a later entry.
        if expected_count is None:
            expected_count = len(probs_list)
        if expected_count != len(probs_list):
            raise ValueError(
                f"Reading `{parquet_path}`: inconsistent number of "
                f"probs across entries: len({probs_list})!={expected_count}"
            )
        return probs_list

    df_probs = pd.read_parquet(parquet_path)
    rows = df_probs.to_numpy().tolist()
    return [[to_list(probs) for probs in batch] for batch in rows]
# The inference probabilities (`probabilities`) are structured as follows:
# (the example below assumes 4 batches of batch_size=2 and, for each of these,
# 4 probabilities corresponding to the multiple choices A, B, C, D)
#
# [
#     [                                          <-- batch no 0:
#         [p_0_0_A, p_0_0_B, p_0_0_C, p_0_0_D],  <-- batch index = 0
#         [p_0_1_A, p_0_1_B, p_0_1_C, p_0_1_D],  <-- batch index = 1
#     ],
#     [                                          <-- batch no 1:
#         [p_1_0_A, p_1_0_B, p_1_0_C, p_1_0_D],  <-- batch index = 0
#         [p_1_1_A, p_1_1_B, p_1_1_C, p_1_1_D],  <-- batch index = 1
#     ],
#     [                                          <-- batch no 2:
#         [p_2_0_A, p_2_0_B, p_2_0_C, p_2_0_D],  <-- batch index = 0
#         [p_2_1_A, p_2_1_B, p_2_1_C, p_2_1_D],  <-- batch index = 1
#     ],
#     [                                          <-- batch no 3:
#         [p_3_0_A, p_3_0_B, p_3_0_C, p_3_0_D],  <-- batch index = 0
#         [p_3_1_A, p_3_1_B, p_3_1_C, p_3_1_D],  <-- batch index = 1
#     ]
# ]
#
# We save these into a .csv file of the following format:
# - Every row corresponds to a batch.
# - Within each row, the batch items are strings separated by comma (,).
# - Each item (string) contains a list of probabilities (floats).
#
#            batch index = 0                        batch index = 1          batch no
#   <-------------------------------->  ,  <-------------------------------->    |
# "[p_0_0_A, p_0_0_B, p_0_0_C, p_0_0_D]","[p_0_1_A, p_0_1_B, p_0_1_C, p_0_1_D]" <--0
# "[p_1_0_A, p_1_0_B, p_1_0_C, p_1_0_D]","[p_1_1_A, p_1_1_B, p_1_1_C, p_1_1_D]" <--1
# "[p_2_0_A, p_2_0_B, p_2_0_C, p_2_0_D]","[p_2_1_A, p_2_1_B, p_2_1_C, p_2_1_D]" <--2
# "[p_3_0_A, p_3_0_B, p_3_0_C, p_3_0_D]","[p_3_1_A, p_3_1_B, p_3_1_C, p_3_1_D]" <--3
#
def save_infer_prob_csv(output_filepath: str, probabilities: list[list[list[float]]]):
    """Save batched probabilities into a csv file.

    Every element of `probabilities` is written as one csv row. Each item of a
    row is written in its string form (e.g. `"[0.1, 0.2]"`), quoted by the csv
    writer because it contains commas.

    Args:
        output_filepath: Path of the csv file to create or overwrite.
        probabilities: Nested lists of probabilities, one outer list per row.
    """
    # `newline=""` is required by the `csv` module: the writer emits its own
    # "\r\n" terminators, and text-mode newline translation would otherwise
    # turn them into "\r\r\n" on platforms whose line separator is "\r\n".
    with open(output_filepath, "w", newline="") as write_obj:
        csv.writer(write_obj).writerows(probabilities)
def load_infer_prob_csv(input_filepath: str) -> list[list[list[float]]]:
    """Retrieve batched probabilities from a csv file.

    Every csv row is read as one batch; every field of a row is parsed with
    `str_to_float_list` into a list of floats.

    Args:
        input_filepath: Path of the csv file to read.

    Returns:
        Nested lists: one list per csv row, holding one list of floats per
        field of that row. An empty file yields an empty list.

    Raises:
        FileNotFoundError: If `input_filepath` does not exist.
        ValueError: If a field cannot be parsed by `str_to_float_list`, or if
            the fields do not all hold the same number of probabilities.
    """
    expected_count = None
    probabilities = []
    try:
        # `newline=""` lets the `csv` module handle line endings itself, as its
        # documentation requires for file objects passed to `csv.reader`.
        with open(input_filepath, newline="") as read_obj:
            for batch in csv.reader(read_obj):
                probabilities_batch = []
                for entry in batch:
                    probs_list = str_to_float_list(entry)
                    # Number of probabilities must be the same for all entries.
                    if expected_count is None:
                        expected_count = len(probs_list)
                    if expected_count != len(probs_list):
                        raise ValueError(
                            f"Reading {input_filepath}: inconsistent number of probs "
                            f"across entries: len({probs_list}) != "
                            f"{expected_count}"
                        )
                    probabilities_batch.append(probs_list)
                probabilities.append(probabilities_batch)
    except FileNotFoundError as err:
        # Name this function in the message (not the parquet loader's function
        # object) and keep the original exception as the cause.
        raise FileNotFoundError(
            f"load_infer_prob_csv: Path {input_filepath} not found!"
        ) from err
    return probabilities
def str_to_float_list(input: str) -> list[float]:
    """Convert an `str` representing a list of `floats` to an actual list of `floats`.

    Example:
        input: `[1.1, 2.2, 3.3]` => output: [1.1, 2.2, 3.3]

    Args:
        input: String of comma-separated numbers enclosed in '[' and ']'.
            Whitespace around the individual numbers is ignored.

    Returns:
        The numbers of `input`, in order, as floats.

    Raises:
        ValueError: If `input` is not enclosed in '[' and ']', contains no
            items, or contains an item that cannot be converted to `float`.
    """
    # 1) Get rid of '[' and ']'. `startswith`/`endswith` also cope with an empty
    #    string, which indexing (`input[0]`) would reject with an `IndexError`.
    if not (input.startswith("[") and input.endswith("]")):
        raise ValueError(
            f"Input `{input}` must start with '[' and end with ']' to represent a list"
        )
    inner = input[1:-1]

    # 2) Convert string to a list of items. `str.split` never returns an empty
    #    list, so emptiness has to be checked on the string itself.
    if not inner.strip():
        raise ValueError(f"List `{input}` does NOT contain any items")
    list_of_items = [item.strip() for item in inner.split(",")]

    # 3) Cast all list items to `float`.
    try:
        return [float(item) for item in list_of_items]
    except ValueError as err:
        raise ValueError(
            f"List `{list_of_items}` should contain probabilities"
        ) from err