mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
refactor: update score method to use LLM with detailed rubric for joke evaluation
This commit is contained in:
parent
db1e68d2ab
commit
96043a968f
1 changed files with 33 additions and 6 deletions
|
|
@ -74,15 +74,42 @@ class HumorEnv(BaseEnv):
|
|||
return scored, []
|
||||
|
||||
async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
|
||||
"""
|
||||
Score each generated joke using the detailed rubric via an LLM call.
|
||||
"""
|
||||
scores = ScoredDataGroup(tokens=[], masks=[], scores=[])
|
||||
for (messages, _), idx in zip(rollout_group_data, range(len(rollout_group_data))):
|
||||
expected = self.train[idx % len(self.train)]["response"].strip()
|
||||
output = messages[-1]["content"].strip()
|
||||
score_val = 1.0 if output == expected else 0.0
|
||||
out = tokenize_for_trainer(self.tokenizer, list(messages))
|
||||
# All items share same comedian/format
|
||||
fmt = self.train[0]["format"]
|
||||
comedian = self.train[0]["comedian"]
|
||||
for messages, _ in rollout_group_data:
|
||||
joke = messages[-1]["content"].strip()
|
||||
# Build the rubric prompt
|
||||
rubric_prompt = (
|
||||
f"1. Relevance to the format ({fmt}): Evaluate the joke \"{joke}\". Score: X (0-2)\n"
|
||||
f"2. Style consistency ({comedian}): Evaluate the joke \"{joke}\". Score: X (0-2)\n"
|
||||
f"3. Creativity: Evaluate the joke \"{joke}\". Score: X (0-3)\n"
|
||||
f"4. Humor effectiveness: Evaluate the joke \"{joke}\". Score: X (0-3)\n"
|
||||
f"5. Virality: Evaluate the joke \"{joke}\". Score: X (0-3)\n"
|
||||
f"6. Cognitive coherence: Evaluate the joke \"{joke}\". Score: X (0-3)\n"
|
||||
"Please provide each score on its own line as 'Score: <number>'."
|
||||
)
|
||||
judge = await self.server.chat_completion(
|
||||
messages=[{"role": "user", "content": rubric_prompt}],
|
||||
n=1,
|
||||
max_tokens=512,
|
||||
)
|
||||
text = judge.choices[0].message.content
|
||||
# Parse out all Score: X lines
|
||||
nums = [
|
||||
int(line.split("Score:")[-1].strip().split()[0])
|
||||
for line in text.splitlines()
|
||||
if "Score:" in line
|
||||
]
|
||||
avg_score = sum(nums) / len(nums) if nums else 0.0
|
||||
out = tokenize_for_trainer(self.tokenizer, messages)
|
||||
scores["tokens"].append(out["tokens"])
|
||||
scores["masks"].append(out["masks"])
|
||||
scores["scores"].append(score_val)
|
||||
scores["scores"].append(avg_score)
|
||||
return scores
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[dict] = None):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue