mirror of https://github.com/NousResearch/atropos.git
add printing

commit f4de3ad6f5 (parent a11af27298)

2 changed files with 61 additions and 9 deletions
@@ -155,10 +155,6 @@ class BaseEnvConfig(BaseModel):
         default=None,
         description="Path to save the groups, if set, will write groups to this jsonl",
     )
-    data_dir_to_save_evals: Optional[str] = Field(
-        default=None,
-        description="Directory to save evaluation results",
-    )
     min_items_sent_before_logging: int = Field(
         default=2,
         description="Minimum number of items sent before logging, if 0 or less, logs every time",
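
The data_dir_to_save_evals field removed here is still read by the logging code below (self.config.data_dir_to_save_evals), which suggests this hunk drops a duplicated definition. A minimal sketch of driving the logging path from config (field names are from the diff; the construction itself is illustrative):

    # Illustrative only: enable evaluation logging via BaseEnvConfig.
    config = BaseEnvConfig(
        data_dir_to_save_evals="./eval_results",  # target dir for metrics.json / samples.jsonl
        min_items_sent_before_logging=2,          # 0 or less logs every time
    )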
@@ -649,6 +645,8 @@ class BaseEnv(ABC):
         start_time: Optional[float] = None,
         end_time: Optional[float] = None,
         generation_parameters: Optional[Dict] = None,
+        samples: Optional[List[Dict]] = None,
+        verbose: bool = True,
     ):
         """
         Log evaluation results to a JSON file in the format expected by nous-evals.
@@ -660,6 +658,8 @@ class BaseEnv(ABC):
             start_time: Start time of evaluation (unix timestamp)
             end_time: End time of evaluation (unix timestamp)
             generation_parameters: Dictionary of generation parameters used
+            samples: List of sample dictionaries to save to samples.jsonl
+            verbose: If True, print a markdown table of the metrics
         """
         if self.config.data_dir_to_save_evals is None:
             logger.warning("data_dir_to_save_evals is not set, skipping evaluation logging")
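
With the two new parameters, a call to evaluate_log takes roughly this shape (a sketch; every value is illustrative, not from the commit):

    # Illustrative call shape for the updated signature.
    await self.evaluate_log(
        task_name="gsm8k",                # assumed parameter; referenced by the printing code below
        metrics={"eval/percent_correct": 0.85},
        start_time=1700000000.0,          # unix timestamp
        end_time=1700000123.4,
        generation_parameters={"temperature": 0.0},
        samples=[{"question": "...", "answer": "...", "correct": True}],
        verbose=True,                     # print the markdown metrics table
    )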
@@ -667,14 +667,14 @@ class BaseEnv(ABC):

         import os
         import json
+        import jsonlines
         from datetime import datetime

         # Create directory if it doesn't exist
         os.makedirs(self.config.data_dir_to_save_evals, exist_ok=True)

-        # Generate filename with timestamp
-        timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")[:-3] + "-00-00"
-        filename = f"{timestamp}.json"
+        # Generate filename
+        filename = "metrics.json"
         filepath = os.path.join(self.config.data_dir_to_save_evals, filename)

         # Default values
@@ -700,6 +700,24 @@ class BaseEnv(ABC):
         # Merge config params with passed params (passed params take precedence)
         merged_gen_params = {**config_gen_params, **generation_parameters}

+        # Print metrics table if verbose
+        if verbose:
+            print("\n" + "=" * 60)
+            print(f"Evaluation Results: {task_name}")
+            print("=" * 60)
+            print(f"|{'Groups':<20}|{'Version':<7}|{'Filter':<6}|{'n-shot':<6}|{'Metric':<10}|{' ':<3}|{'Value':<10}|{' ':<3}|{'Stderr':<10}|")
+            print(f"|{'-' * 20}|{'-' * 7}|{'-' * 6}|{'-' * 6}|{'-' * 10}|{'-' * 3}|{'-' * 10}|{'-' * 3}|{'-' * 10}|")
+
+            # Main task row
+            for metric_name, metric_value in metrics.items():
+                clean_metric_name = metric_name.replace("eval/", "").replace("_", " ")
+                direction = "↑" if "correct" in metric_name or "acc" in metric_name else " "
+                print(f"|{task_name:<20}|{1:<7}|{'none':<6}|{'':<6}|{clean_metric_name:<10}|{direction:<3}|{metric_value:<10.4f}|{'±':<3}|{'0.0000':<10}|")
+
+            print("=" * 60)
+            print(f"Evaluation completed in {end_time - start_time:.2f} seconds")
+            print("=" * 60 + "\n")

         # Build the evaluation result structure
         task_key = f"atropos|{task_name}|0"
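
With task_name "gsm8k" and a single 0.85 percent-correct metric, the verbose block prints along these lines (illustrative output; a cell wider than its column, like "percent correct", simply overflows, since str.format pads but never truncates):

    ============================================================
    Evaluation Results: gsm8k
    ============================================================
    |Groups              |Version|Filter|n-shot|Metric    |   |Value     |   |Stderr    |
    |--------------------|-------|------|------|----------|---|----------|---|----------|
    |gsm8k               |1      |none  |      |percent correct|↑  |0.8500    |±  |0.0000    |
    ============================================================
    Evaluation completed in 123.40 seconds
    ============================================================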
@@ -802,11 +820,19 @@ class BaseEnv(ABC):
             }
         }

-        # Write to file
+        # Write main results to JSON file
         with open(filepath, 'w') as f:
             json.dump(eval_result, f, indent=2)

+        print(f"Evaluation results saved to {filepath}")
+
+        # Write samples to JSONL file if provided
+        if samples:
+            samples_filepath = os.path.join(self.config.data_dir_to_save_evals, "samples.jsonl")
+            with jsonlines.open(samples_filepath, 'w') as writer:
+                for sample in samples:
+                    writer.write(sample)
+            print(f"Evaluation samples saved to {samples_filepath}")

     @retry(
         stop=stop_after_attempt(3),
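
Reading the two artifacts back is straightforward (a sketch assuming data_dir_to_save_evals="./eval_results"; the metrics.json and samples.jsonl names come from the diff):

    # Illustrative: load what evaluate_log wrote.
    import json
    import jsonlines

    with open("./eval_results/metrics.json") as f:
        eval_result = json.load(f)      # nous-evals style results dict

    with jsonlines.open("./eval_results/samples.jsonl") as reader:
        samples = list(reader)          # one dict per evaluated sample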
@@ -159,13 +159,39 @@ class GSM8kEnv(BaseEnv):
         return score

     async def evaluate(self, *args, **kwargs):
+        import time
+        start_time = time.time()
+
         eval_tasks = []
         for item in self.test:
             eval_tasks.append(
                 self.rollout_and_score_eval(item["question"], item["gold_answer"])
             )
         scores = await tqdm_asyncio.gather(*eval_tasks)
-        self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))
+        percent_correct = sum(scores) / len(scores)
+
+        end_time = time.time()
+
+        # Add to existing metrics for wandb
+        self.eval_metrics.append(("eval/percent_correct", percent_correct))
+
+        # Log evaluation results
+        eval_metrics = {
+            "eval/percent_correct": percent_correct,
+            "eval/total_samples": len(scores),
+            "eval/correct_samples": sum(scores),
+        }
+
+        await self.evaluate_log(
+            metrics=eval_metrics,
+            start_time=start_time,
+            end_time=end_time,
+            generation_parameters={
+                "temperature": 0.0,
+                "max_tokens": self.config.max_token_length,
+                "split": "eval"
+            }
+        )

     async def collect_trajectories(
         self, item: GSM8kRow
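
As a sanity check on the aggregation, the three metrics reduce to simple arithmetic over the per-item scores returned by rollout_and_score_eval (toy values, not from the commit):

    # Toy example of the metric aggregation performed in evaluate().
    scores = [1.0, 0.0, 1.0, 1.0]                 # per-item correctness scores
    percent_correct = sum(scores) / len(scores)   # 0.75
    eval_metrics = {
        "eval/percent_correct": percent_correct,  # 0.75
        "eval/total_samples": len(scores),        # 4
        "eval/correct_samples": sum(scores),      # 3.0
    }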