refactor: update score method to use LLM with detailed rubric for joke evaluation

Author: Kirill Igumenshchev (aider)
Date:   2025-05-18 17:48:14 -07:00
parent db1e68d2ab
commit 96043a968f

@@ -74,15 +74,42 @@ class HumorEnv(BaseEnv):
         return scored, []
     async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
         """
+        Score each generated joke using the detailed rubric via an LLM call.
         """
         scores = ScoredDataGroup(tokens=[], masks=[], scores=[])
-        for (messages, _), idx in zip(rollout_group_data, range(len(rollout_group_data))):
-            expected = self.train[idx % len(self.train)]["response"].strip()
-            output = messages[-1]["content"].strip()
-            score_val = 1.0 if output == expected else 0.0
-            out = tokenize_for_trainer(self.tokenizer, list(messages))
+        # All items share same comedian/format
+        fmt = self.train[0]["format"]
+        comedian = self.train[0]["comedian"]
+        for messages, _ in rollout_group_data:
+            joke = messages[-1]["content"].strip()
+            # Build the rubric prompt
+            rubric_prompt = (
+                f"1. Relevance to the format ({fmt}): Evaluate the joke \"{joke}\". Score: X (0-2)\n"
+                f"2. Style consistency ({comedian}): Evaluate the joke \"{joke}\". Score: X (0-2)\n"
+                f"3. Creativity: Evaluate the joke \"{joke}\". Score: X (0-3)\n"
+                f"4. Humor effectiveness: Evaluate the joke \"{joke}\". Score: X (0-3)\n"
+                f"5. Virality: Evaluate the joke \"{joke}\". Score: X (0-3)\n"
+                f"6. Cognitive coherence: Evaluate the joke \"{joke}\". Score: X (0-3)\n"
+                "Please provide each score on its own line as 'Score: <number>'."
+            )
+            judge = await self.server.chat_completion(
+                messages=[{"role": "user", "content": rubric_prompt}],
+                n=1,
+                max_tokens=512,
+            )
+            text = judge.choices[0].message.content
+            # Parse out all Score: X lines
+            nums = [
+                int(line.split("Score:")[-1].strip().split()[0])
+                for line in text.splitlines()
+                if "Score:" in line
+            ]
+            avg_score = sum(nums) / len(nums) if nums else 0.0
+            out = tokenize_for_trainer(self.tokenizer, messages)
             scores["tokens"].append(out["tokens"])
             scores["masks"].append(out["masks"])
-            scores["scores"].append(score_val)
+            scores["scores"].append(avg_score)
         return scores
     async def wandb_log(self, wandb_metrics: Optional[dict] = None):
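
For reference, the judge-reply parsing added above can be exercised in isolation. The sketch below is a hypothetical standalone version, not part of the commit: the helper name parse_rubric_scores and the sample reply are invented for illustration, and it uses float() instead of the commit's int() so that a fractional reply such as "Score: 2.5" does not raise ValueError.

    def parse_rubric_scores(text: str) -> float:
        """Average every 'Score: <number>' value found in a judge reply."""
        nums = []
        for line in text.splitlines():
            if "Score:" not in line:
                continue
            token = line.split("Score:")[-1].strip().split()[0]
            # float() tolerates replies like "Score: 2.5"; the int() used in
            # the commit would raise ValueError on such a line.
            nums.append(float(token))
        return sum(nums) / len(nums) if nums else 0.0

    # Invented judge reply matching the requested 'Score: <number>' format.
    reply = (
        "1. Relevance to the format: Score: 2\n"
        "2. Style consistency: Score: 1\n"
        "3. Creativity: Score: 3\n"
        "4. Humor effectiveness: Score: 2\n"
        "5. Virality: Score: 1\n"
        "6. Cognitive coherence: Score: 3\n"
    )
    print(parse_rubric_scores(reply))  # (2+1+3+2+1+3)/6 = 2.0

Note that the six criteria have unequal maxima (0-2 for the first two, 0-3 for the rest), so the averaged reward tops out at 16/6 ≈ 2.67 rather than 1.0, and a reply with no parseable "Score:" lines falls back to 0.0.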