diff --git a/environments/rlaif_server.py b/environments/rlaif_server.py index cfacab71..275ac786 100644 --- a/environments/rlaif_server.py +++ b/environments/rlaif_server.py @@ -225,7 +225,9 @@ class RLAIFEnv(BaseEnv): scores["tokens"].append(tokens) scores["masks"].append(masks) scores["inference_logprobs"].append(logprobs) - scores["scores"].append(1.0 if item["finish_reason"] != "length" else -1.0) + scores["scores"].append( + 1.0 if item["finish_reason"] != "length" else -1.0 + ) return scores else: fwd_fmt = RLAIF_user_prompt_format_str.format(