mirror of
https://github.com/collinear-ai/yc-bench.git
synced 2026-04-30 17:40:40 +00:00
Merge pull request #11 from collinear-ai/feat/multi-episode
Add multi-episode setting with scratchpad carryover
This commit is contained in:
commit
d976b9cbb4
9 changed files with 490 additions and 82 deletions
|
|
@ -118,6 +118,7 @@ def run_agent_loop(
|
|||
max_turns: int | None = None,
|
||||
on_turn_start=None,
|
||||
on_turn=None,
|
||||
episode: int = 1,
|
||||
) -> RunState:
|
||||
run_state.start()
|
||||
turns_since_resume = 0 # consecutive turns without sim resume
|
||||
|
|
@ -136,7 +137,7 @@ def run_agent_loop(
|
|||
if run_state.turn_count == 0:
|
||||
with db_factory() as db:
|
||||
snapshot = _snapshot_state(db, company_id)
|
||||
user_input = build_initial_user_prompt(**snapshot)
|
||||
user_input = build_initial_user_prompt(**snapshot, episode=episode)
|
||||
elif run_state.next_user_input is not None:
|
||||
user_input = run_state.next_user_input
|
||||
run_state.next_user_input = None
|
||||
|
|
|
|||
|
|
@ -173,6 +173,7 @@ def build_initial_user_prompt(
|
|||
employee_count: int,
|
||||
monthly_payroll_cents: int,
|
||||
bankrupt: bool,
|
||||
episode: int = 1,
|
||||
) -> str:
|
||||
"""Build the one-time initial user message at run start."""
|
||||
runway_months = (
|
||||
|
|
@ -188,7 +189,18 @@ def build_initial_user_prompt(
|
|||
)
|
||||
runway_str = f"~{runway_months} months" if runway_months is not None else "∞"
|
||||
|
||||
lines = [
|
||||
lines = []
|
||||
if episode > 1:
|
||||
lines.extend([
|
||||
f"## Episode {episode} — Restarting After Bankruptcy",
|
||||
"",
|
||||
f"You went bankrupt in episode {episode - 1}. The simulation has been reset,",
|
||||
"but your **scratchpad notes from the previous episode are preserved**.",
|
||||
"Read your scratchpad (`yc-bench scratchpad read`) to review your notes",
|
||||
"and learn from past mistakes before taking action.",
|
||||
"",
|
||||
])
|
||||
lines.extend([
|
||||
"## Simulation Start — Take Immediate Action",
|
||||
f"- current_time: {sim_time}",
|
||||
f"- horizon_end: {horizon_end}",
|
||||
|
|
@ -210,7 +222,7 @@ def build_initial_user_prompt(
|
|||
"7. `yc-bench sim resume` — advance time to collect the first task completion event",
|
||||
"",
|
||||
"Do not spend multiple turns just browsing. Accept and dispatch tasks immediately.",
|
||||
]
|
||||
])
|
||||
if bankrupt:
|
||||
lines.append("WARNING: company is already bankrupt at initialization.")
|
||||
return "\n".join(lines)
|
||||
|
|
|
|||
|
|
@ -41,6 +41,10 @@ class RunState:
|
|||
next_user_input: Optional[str] = None
|
||||
total_cost_usd: float = 0.0
|
||||
|
||||
# Multi-episode tracking
|
||||
current_episode: int = 1
|
||||
episode_results: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
def start(self) -> None:
|
||||
self.started_at = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
|
@ -66,18 +70,14 @@ class RunState:
|
|||
return True
|
||||
return False
|
||||
|
||||
def full_rollout(self) -> Dict[str, Any]:
|
||||
"""Full results including transcript for saving to disk."""
|
||||
return {
|
||||
"session_id": self.session_id,
|
||||
"model": self.model,
|
||||
"seed": self.seed,
|
||||
"horizon_years": self.horizon_years,
|
||||
def finish_episode(self) -> Dict[str, Any]:
|
||||
"""Snapshot current episode state into episode_results and return it."""
|
||||
episode_data = {
|
||||
"episode": self.current_episode,
|
||||
"turns_completed": self.turn_count,
|
||||
"terminal": self.terminal,
|
||||
"terminal_reason": self.terminal_reason.value if self.terminal_reason else None,
|
||||
"terminal_detail": self.terminal_detail,
|
||||
"total_cost_usd": round(self.total_cost_usd, 6),
|
||||
"cost_usd": round(self.total_cost_usd, 6),
|
||||
"started_at": self.started_at,
|
||||
"ended_at": self.ended_at,
|
||||
"transcript": [
|
||||
|
|
@ -91,11 +91,66 @@ class RunState:
|
|||
for t in self.transcript
|
||||
],
|
||||
}
|
||||
self.episode_results.append(episode_data)
|
||||
return episode_data
|
||||
|
||||
def reset_for_new_episode(self) -> None:
|
||||
"""Reset mutable state for a new episode, preserving episode_results."""
|
||||
self.current_episode += 1
|
||||
self.turn_count = 0
|
||||
self.terminal = False
|
||||
self.terminal_reason = None
|
||||
self.terminal_detail = None
|
||||
self.started_at = None
|
||||
self.ended_at = None
|
||||
self.transcript = []
|
||||
self.next_user_input = None
|
||||
self.total_cost_usd = 0.0
|
||||
|
||||
def full_rollout(self) -> Dict[str, Any]:
|
||||
"""Full results including transcript for saving to disk."""
|
||||
base = {
|
||||
"session_id": self.session_id,
|
||||
"model": self.model,
|
||||
"seed": self.seed,
|
||||
"horizon_years": self.horizon_years,
|
||||
"total_episodes": self.current_episode,
|
||||
"terminal": self.terminal,
|
||||
"terminal_reason": self.terminal_reason.value if self.terminal_reason else None,
|
||||
"terminal_detail": self.terminal_detail,
|
||||
"started_at": self.started_at,
|
||||
"ended_at": self.ended_at,
|
||||
}
|
||||
if self.episode_results:
|
||||
# Multi-episode: include all episode data
|
||||
total_turns = sum(ep["turns_completed"] for ep in self.episode_results)
|
||||
total_cost = sum(ep["cost_usd"] for ep in self.episode_results)
|
||||
base["turns_completed"] = total_turns
|
||||
base["total_cost_usd"] = round(total_cost, 6)
|
||||
base["episodes"] = self.episode_results
|
||||
else:
|
||||
# Single-episode (backward compat): flat transcript
|
||||
base["turns_completed"] = self.turn_count
|
||||
base["total_cost_usd"] = round(self.total_cost_usd, 6)
|
||||
base["transcript"] = [
|
||||
{
|
||||
"turn": t.turn,
|
||||
"timestamp": t.timestamp,
|
||||
"user_input": t.user_input,
|
||||
"agent_output": t.agent_output,
|
||||
"commands_executed": t.commands_executed,
|
||||
}
|
||||
for t in self.transcript
|
||||
]
|
||||
return base
|
||||
|
||||
def summary(self) -> Dict[str, Any]:
|
||||
"""Summary without transcript for logging."""
|
||||
rollout = self.full_rollout()
|
||||
rollout.pop("transcript", None)
|
||||
if "episodes" in rollout:
|
||||
for ep in rollout["episodes"]:
|
||||
ep.pop("transcript", None)
|
||||
return rollout
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -99,6 +99,10 @@ def run_command_cli(
|
|||
help="Preset name ('default', 'fast_test', 'high_reward') or path to a .toml file",
|
||||
),
|
||||
no_live: bool = typer.Option(False, "--no-live", help="Disable the live terminal dashboard"),
|
||||
max_episodes: int = typer.Option(
|
||||
1, "--max-episodes",
|
||||
help="Max episodes (restarts after bankruptcy with scratchpad carried over). Default: 1",
|
||||
),
|
||||
):
|
||||
"""Run a full benchmark: migrate DB, seed world, run agent loop to completion."""
|
||||
from dotenv import find_dotenv, load_dotenv
|
||||
|
|
@ -114,6 +118,7 @@ def run_command_cli(
|
|||
start_date=start_date,
|
||||
config_name=config_name,
|
||||
no_live=no_live,
|
||||
max_episodes=max_episodes,
|
||||
)
|
||||
raise SystemExit(run_benchmark(args))
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ class RunArgs:
|
|||
start_date: str
|
||||
config_name: str = "default"
|
||||
no_live: bool = False
|
||||
max_episodes: int = 1
|
||||
|
||||
def build_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
|
@ -32,6 +33,10 @@ def build_parser():
|
|||
"--no-live", action="store_true", default=False,
|
||||
help="Disable the live terminal dashboard (show raw log output instead)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-episodes", type=int, default=1,
|
||||
help="Max episodes (restarts after bankruptcy with scratchpad carried over). Default: 1",
|
||||
)
|
||||
return parser
|
||||
|
||||
def parse_run_args(argv):
|
||||
|
|
@ -46,10 +51,13 @@ def parse_run_args(argv):
|
|||
start_date=ns.start_date,
|
||||
config_name=ns.config_name,
|
||||
no_live=ns.no_live,
|
||||
max_episodes=ns.max_episodes,
|
||||
)
|
||||
|
||||
def _validate(ns, parser):
|
||||
if ns.horizon_years is not None and ns.horizon_years <= 0:
|
||||
parser.error("--horizon-years must be int > 0")
|
||||
if ns.max_episodes < 1:
|
||||
parser.error("--max-episodes must be int >= 1")
|
||||
|
||||
__all__ = ["RunArgs", "build_parser", "parse_run_args"]
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from pathlib import Path
|
|||
|
||||
from ..agent.commands.executor import run_command
|
||||
from ..agent.loop import run_agent_loop
|
||||
from ..agent.run_state import RunState
|
||||
from ..agent.run_state import RunState, TerminalReason
|
||||
from ..agent.runtime.factory import build_runtime
|
||||
from ..agent.runtime.schemas import RuntimeSettings
|
||||
from ..db.session import build_engine, build_session_factory, session_scope, init_db
|
||||
|
|
@ -154,6 +154,37 @@ def _redirect_all_logging_to_file(log_file: Path) -> None:
|
|||
lg.addHandler(file_handler)
|
||||
|
||||
|
||||
def _build_db_url(args, episode: int, max_episodes: int) -> str:
|
||||
"""Build SQLite DATABASE_URL, adding episode suffix when multi-episode."""
|
||||
slug = args.model.replace("/", "_")
|
||||
db_dir = Path("db")
|
||||
db_dir.mkdir(exist_ok=True)
|
||||
base = f"{args.config_name}_{args.seed}_{slug}"
|
||||
if max_episodes > 1:
|
||||
return f"sqlite:///{db_dir}/{base}.ep{episode}.db"
|
||||
return f"sqlite:///{db_dir}/{base}.db"
|
||||
|
||||
|
||||
def _read_scratchpad(db_factory, company_id) -> str:
|
||||
"""Read scratchpad content from the current DB."""
|
||||
from ..db.models.scratchpad import Scratchpad
|
||||
with db_factory() as db:
|
||||
row = db.query(Scratchpad).filter(Scratchpad.company_id == company_id).first()
|
||||
return row.content if row else ""
|
||||
|
||||
|
||||
def _write_scratchpad(db_factory, company_id, content: str) -> None:
|
||||
"""Write scratchpad content into the current DB (upsert)."""
|
||||
from ..db.models.scratchpad import Scratchpad
|
||||
with db_factory() as db:
|
||||
row = db.query(Scratchpad).filter(Scratchpad.company_id == company_id).first()
|
||||
if row is None:
|
||||
db.add(Scratchpad(company_id=company_id, content=content))
|
||||
else:
|
||||
row.content = content
|
||||
db.flush()
|
||||
|
||||
|
||||
def run_benchmark(args):
|
||||
"""Run a full benchmark: migrate, seed, loop until terminal."""
|
||||
logging.basicConfig(
|
||||
|
|
@ -184,31 +215,11 @@ def run_benchmark(args):
|
|||
_redirect_all_logging_to_file(log_file)
|
||||
|
||||
logger.info(
|
||||
"YC-Bench starting: experiment=%s model=%s seed=%d horizon=%dy",
|
||||
experiment_cfg.name, args.model, args.seed, horizon_years,
|
||||
"YC-Bench starting: experiment=%s model=%s seed=%d horizon=%dy max_episodes=%d",
|
||||
experiment_cfg.name, args.model, args.seed, horizon_years, args.max_episodes,
|
||||
)
|
||||
|
||||
# 1. Build engine and create all tables
|
||||
# If DATABASE_URL is not explicitly set, default to db/<config>_<seed>_<slug>.db
|
||||
if not os.environ.get("DATABASE_URL"):
|
||||
slug = args.model.replace("/", "_")
|
||||
db_dir = Path("db")
|
||||
db_dir.mkdir(exist_ok=True)
|
||||
os.environ["DATABASE_URL"] = f"sqlite:///{db_dir}/{args.config_name}_{args.seed}_{slug}.db"
|
||||
|
||||
engine = build_engine()
|
||||
init_db(engine)
|
||||
factory = build_session_factory(engine)
|
||||
|
||||
@contextmanager
|
||||
def db_factory():
|
||||
with session_scope(factory) as session:
|
||||
yield session
|
||||
|
||||
# 3. Init simulation using experiment world config
|
||||
company_id = _init_simulation(db_factory, args, experiment_cfg, horizon_years)
|
||||
|
||||
# 4. Build runtime settings from experiment AgentConfig
|
||||
# Build runtime settings (shared across episodes)
|
||||
agent_cfg = experiment_cfg.agent
|
||||
settings = RuntimeSettings(
|
||||
model=agent_cfg.model,
|
||||
|
|
@ -222,7 +233,7 @@ def run_benchmark(args):
|
|||
)
|
||||
runtime = build_runtime(settings, command_executor=run_command)
|
||||
|
||||
# 5. Build run state
|
||||
# Build run state (persists across episodes)
|
||||
session_id = f"run-{args.seed}-{args.model}"
|
||||
run_state = RunState(
|
||||
session_id=session_id,
|
||||
|
|
@ -231,66 +242,110 @@ def run_benchmark(args):
|
|||
horizon_years=horizon_years,
|
||||
)
|
||||
|
||||
# 6. Set up live dashboard (or not)
|
||||
dashboard = None
|
||||
on_turn_start = None
|
||||
on_turn = None
|
||||
|
||||
if use_live:
|
||||
from .dashboard import BenchmarkDashboard
|
||||
|
||||
dashboard = BenchmarkDashboard(
|
||||
model=args.model,
|
||||
seed=args.seed,
|
||||
config_name=args.config_name,
|
||||
db_factory=db_factory,
|
||||
company_id=company_id,
|
||||
)
|
||||
|
||||
def on_turn_start(turn_num):
|
||||
dashboard.mark_turn_start(turn_num)
|
||||
|
||||
def on_turn(snapshot, rs, commands):
|
||||
dashboard.update(snapshot, rs, commands)
|
||||
|
||||
# 7. Run agent loop
|
||||
loop_cfg = experiment_cfg.loop
|
||||
try:
|
||||
max_episodes = args.max_episodes
|
||||
carried_scratchpad = ""
|
||||
|
||||
for episode in range(1, max_episodes + 1):
|
||||
logger.info("=== Episode %d / %d ===", episode, max_episodes)
|
||||
|
||||
# 1. Build engine for this episode's DB
|
||||
db_url = _build_db_url(args, episode, max_episodes)
|
||||
os.environ["DATABASE_URL"] = db_url
|
||||
engine = build_engine()
|
||||
init_db(engine)
|
||||
factory = build_session_factory(engine)
|
||||
|
||||
@contextmanager
|
||||
def db_factory(_factory=factory):
|
||||
with session_scope(_factory) as session:
|
||||
yield session
|
||||
|
||||
# 2. Init simulation
|
||||
company_id = _init_simulation(db_factory, args, experiment_cfg, horizon_years)
|
||||
|
||||
# 3. Restore scratchpad from previous episode
|
||||
if episode > 1 and carried_scratchpad:
|
||||
_write_scratchpad(db_factory, company_id, carried_scratchpad)
|
||||
logger.info("Restored scratchpad from episode %d (%d chars).", episode - 1, len(carried_scratchpad))
|
||||
|
||||
# 4. Set up live dashboard
|
||||
dashboard = None
|
||||
on_turn_start = None
|
||||
on_turn = None
|
||||
|
||||
if use_live:
|
||||
from .dashboard import BenchmarkDashboard
|
||||
|
||||
dashboard = BenchmarkDashboard(
|
||||
model=args.model,
|
||||
seed=args.seed,
|
||||
config_name=args.config_name,
|
||||
db_factory=db_factory,
|
||||
company_id=company_id,
|
||||
)
|
||||
|
||||
def on_turn_start(turn_num):
|
||||
dashboard.mark_turn_start(turn_num)
|
||||
|
||||
def on_turn(snapshot, rs, commands):
|
||||
dashboard.update(snapshot, rs, commands)
|
||||
|
||||
# 5. Run agent loop for this episode
|
||||
try:
|
||||
if dashboard is not None:
|
||||
dashboard.start()
|
||||
|
||||
final_state = run_agent_loop(
|
||||
runtime=runtime,
|
||||
db_factory=db_factory,
|
||||
company_id=company_id,
|
||||
run_state=run_state,
|
||||
command_executor=run_command,
|
||||
auto_advance_after_turns=loop_cfg.auto_advance_after_turns,
|
||||
max_turns=loop_cfg.max_turns,
|
||||
on_turn_start=on_turn_start,
|
||||
on_turn=on_turn,
|
||||
episode=episode,
|
||||
)
|
||||
finally:
|
||||
if dashboard is not None:
|
||||
dashboard.stop()
|
||||
|
||||
if dashboard is not None:
|
||||
dashboard.start()
|
||||
dashboard.print_final_summary(final_state)
|
||||
|
||||
final_state = run_agent_loop(
|
||||
runtime=runtime,
|
||||
db_factory=db_factory,
|
||||
company_id=company_id,
|
||||
run_state=run_state,
|
||||
command_executor=run_command,
|
||||
auto_advance_after_turns=loop_cfg.auto_advance_after_turns,
|
||||
max_turns=loop_cfg.max_turns,
|
||||
on_turn_start=on_turn_start,
|
||||
on_turn=on_turn,
|
||||
)
|
||||
finally:
|
||||
if dashboard is not None:
|
||||
dashboard.stop()
|
||||
# 6. For multi-episode runs, snapshot this episode's data
|
||||
if max_episodes > 1:
|
||||
run_state.finish_episode()
|
||||
|
||||
# 8. Print final summary
|
||||
if dashboard is not None:
|
||||
dashboard.print_final_summary(final_state)
|
||||
logger.info("Episode %d finished: reason=%s", episode, run_state.terminal_reason)
|
||||
|
||||
# 9. Save full rollout (with transcript) and print summary
|
||||
# 7. If not bankruptcy, or last episode, stop
|
||||
if run_state.terminal_reason != TerminalReason.BANKRUPTCY or episode == max_episodes:
|
||||
break
|
||||
|
||||
# 8. Save scratchpad for next episode, then reset
|
||||
carried_scratchpad = _read_scratchpad(db_factory, company_id)
|
||||
logger.info("Carrying scratchpad to episode %d (%d chars).", episode + 1, len(carried_scratchpad))
|
||||
|
||||
# Clear runtime session (fresh conversation history)
|
||||
runtime.clear_session(session_id)
|
||||
run_state.reset_for_new_episode()
|
||||
engine.dispose()
|
||||
|
||||
# 9. Save full rollout and print summary
|
||||
rollout = final_state.full_rollout()
|
||||
rollout["time_series"] = extract_time_series(db_factory, company_id)
|
||||
summary = final_state.summary()
|
||||
logger.info("Run complete: %s", json.dumps(summary, indent=2))
|
||||
|
||||
# Write full rollout (includes transcript with commands)
|
||||
slug = args.model.replace("/", "_")
|
||||
results_dir = Path("results")
|
||||
results_dir.mkdir(exist_ok=True)
|
||||
results_path = results_dir / f"yc_bench_result_{args.config_name}_{args.seed}_{slug}.json"
|
||||
results_path.write_text(json.dumps(rollout, indent=2))
|
||||
logger.info("Full rollout written to %s (%d turns)", results_path, len(rollout.get("transcript", [])))
|
||||
logger.info("Full rollout written to %s", results_path)
|
||||
|
||||
return 0 if final_state.terminal_reason != "error" else 1
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue