mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
d07ab3e3ce
commit
5cfd1929f1
19 changed files with 708 additions and 452 deletions
|
|
@ -373,7 +373,9 @@ class GSM8KEvalEnv(BaseEnv):
|
|||
|
||||
# Create evaluation tasks
|
||||
async def eval_task(item):
|
||||
return await self.rollout_and_score_eval(item, self.server.servers[0].config)
|
||||
return await self.rollout_and_score_eval(
|
||||
item, self.server.servers[0].config
|
||||
)
|
||||
|
||||
tasks = [eval_task(item) for item in self.eval_items]
|
||||
|
||||
|
|
|
|||
|
|
@ -146,7 +146,7 @@ class MathEnv(BaseEnv):
|
|||
vllm_url = os.environ.get("MATH_ENV_VLLM_URL", "http://localhost:9001/v1")
|
||||
wandb_name = os.environ.get("MATH_ENV_WANDB_NAME", "math-zero-env")
|
||||
max_token_length = int(os.environ.get("MATH_ENV_MAX_TOKENS", "8192"))
|
||||
|
||||
|
||||
env_config = RSConfig(
|
||||
tokenizer_name=model_name,
|
||||
group_size=8,
|
||||
|
|
@ -299,6 +299,7 @@ class MathEnv(BaseEnv):
|
|||
if not self.config.run_evaluation:
|
||||
return
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
eval_tasks = []
|
||||
|
|
@ -320,9 +321,7 @@ class MathEnv(BaseEnv):
|
|||
metrics[f"{subset}_accuracy"] = accuracy
|
||||
metrics[f"{subset}_total"] = len(scores)
|
||||
metrics[f"{subset}_correct"] = sum(scores)
|
||||
self.eval_metrics.append(
|
||||
(f"eval/{subset}_percent_correct", accuracy)
|
||||
)
|
||||
self.eval_metrics.append((f"eval/{subset}_percent_correct", accuracy))
|
||||
|
||||
# overall score
|
||||
all_scores = []
|
||||
|
|
@ -332,9 +331,7 @@ class MathEnv(BaseEnv):
|
|||
metrics["overall_accuracy"] = overall_accuracy
|
||||
metrics["overall_total"] = len(all_scores)
|
||||
metrics["overall_correct"] = sum(all_scores)
|
||||
self.eval_metrics.append(
|
||||
("eval/overall_percent_correct", overall_accuracy)
|
||||
)
|
||||
self.eval_metrics.append(("eval/overall_percent_correct", overall_accuracy))
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
|
|
@ -342,7 +339,9 @@ class MathEnv(BaseEnv):
|
|||
print("\n" + "=" * 60)
|
||||
print("Math Zero Evaluation Results")
|
||||
print("=" * 60)
|
||||
print(f"Overall Accuracy: {overall_accuracy:.2%} ({sum(all_scores)}/{len(all_scores)})")
|
||||
print(
|
||||
f"Overall Accuracy: {overall_accuracy:.2%} ({sum(all_scores)}/{len(all_scores)})"
|
||||
)
|
||||
print("\nPer-subset breakdown:")
|
||||
for subset, scores in sorted(task_lists.items()):
|
||||
acc = sum(scores) / len(scores)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue