mirror of
https://github.com/collinear-ai/yc-bench.git
synced 2026-04-26 17:13:13 +00:00
Merge pull request #11 from collinear-ai/feat/multi-episode
Add multi-episode setting with scratchpad carryover
This commit is contained in:
commit
d976b9cbb4
9 changed files with 490 additions and 82 deletions
|
|
@ -13,6 +13,7 @@ class RunArgs:
|
|||
start_date: str
|
||||
config_name: str = "default"
|
||||
no_live: bool = False
|
||||
max_episodes: int = 1
|
||||
|
||||
def build_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
|
@ -32,6 +33,10 @@ def build_parser():
|
|||
"--no-live", action="store_true", default=False,
|
||||
help="Disable the live terminal dashboard (show raw log output instead)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-episodes", type=int, default=1,
|
||||
help="Max episodes (restarts after bankruptcy with scratchpad carried over). Default: 1",
|
||||
)
|
||||
return parser
|
||||
|
||||
def parse_run_args(argv):
|
||||
|
|
@ -46,10 +51,13 @@ def parse_run_args(argv):
|
|||
start_date=ns.start_date,
|
||||
config_name=ns.config_name,
|
||||
no_live=ns.no_live,
|
||||
max_episodes=ns.max_episodes,
|
||||
)
|
||||
|
||||
def _validate(ns, parser):
|
||||
if ns.horizon_years is not None and ns.horizon_years <= 0:
|
||||
parser.error("--horizon-years must be int > 0")
|
||||
if ns.max_episodes < 1:
|
||||
parser.error("--max-episodes must be int >= 1")
|
||||
|
||||
__all__ = ["RunArgs", "build_parser", "parse_run_args"]
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from pathlib import Path
|
|||
|
||||
from ..agent.commands.executor import run_command
|
||||
from ..agent.loop import run_agent_loop
|
||||
from ..agent.run_state import RunState
|
||||
from ..agent.run_state import RunState, TerminalReason
|
||||
from ..agent.runtime.factory import build_runtime
|
||||
from ..agent.runtime.schemas import RuntimeSettings
|
||||
from ..db.session import build_engine, build_session_factory, session_scope, init_db
|
||||
|
|
@ -154,6 +154,37 @@ def _redirect_all_logging_to_file(log_file: Path) -> None:
|
|||
lg.addHandler(file_handler)
|
||||
|
||||
|
||||
def _build_db_url(args, episode: int, max_episodes: int) -> str:
|
||||
"""Build SQLite DATABASE_URL, adding episode suffix when multi-episode."""
|
||||
slug = args.model.replace("/", "_")
|
||||
db_dir = Path("db")
|
||||
db_dir.mkdir(exist_ok=True)
|
||||
base = f"{args.config_name}_{args.seed}_{slug}"
|
||||
if max_episodes > 1:
|
||||
return f"sqlite:///{db_dir}/{base}.ep{episode}.db"
|
||||
return f"sqlite:///{db_dir}/{base}.db"
|
||||
|
||||
|
||||
def _read_scratchpad(db_factory, company_id) -> str:
    """Return the company's scratchpad text from the current DB, or "" if absent."""
    from ..db.models.scratchpad import Scratchpad

    with db_factory() as db:
        record = (
            db.query(Scratchpad)
            .filter(Scratchpad.company_id == company_id)
            .first()
        )
        if record is None:
            return ""
        return record.content
|
||||
|
||||
|
||||
def _write_scratchpad(db_factory, company_id, content: str) -> None:
    """Upsert the scratchpad row for *company_id* with *content* in the current DB."""
    from ..db.models.scratchpad import Scratchpad

    with db_factory() as db:
        existing = (
            db.query(Scratchpad)
            .filter(Scratchpad.company_id == company_id)
            .first()
        )
        if existing is not None:
            existing.content = content
        else:
            # No row yet for this company: insert a fresh one.
            db.add(Scratchpad(company_id=company_id, content=content))
        db.flush()
|
||||
|
||||
|
||||
def run_benchmark(args):
|
||||
"""Run a full benchmark: migrate, seed, loop until terminal."""
|
||||
logging.basicConfig(
|
||||
|
|
@ -184,31 +215,11 @@ def run_benchmark(args):
|
|||
_redirect_all_logging_to_file(log_file)
|
||||
|
||||
logger.info(
|
||||
"YC-Bench starting: experiment=%s model=%s seed=%d horizon=%dy",
|
||||
experiment_cfg.name, args.model, args.seed, horizon_years,
|
||||
"YC-Bench starting: experiment=%s model=%s seed=%d horizon=%dy max_episodes=%d",
|
||||
experiment_cfg.name, args.model, args.seed, horizon_years, args.max_episodes,
|
||||
)
|
||||
|
||||
# 1. Build engine and create all tables
|
||||
# If DATABASE_URL is not explicitly set, default to db/<config>_<seed>_<slug>.db
|
||||
if not os.environ.get("DATABASE_URL"):
|
||||
slug = args.model.replace("/", "_")
|
||||
db_dir = Path("db")
|
||||
db_dir.mkdir(exist_ok=True)
|
||||
os.environ["DATABASE_URL"] = f"sqlite:///{db_dir}/{args.config_name}_{args.seed}_{slug}.db"
|
||||
|
||||
engine = build_engine()
|
||||
init_db(engine)
|
||||
factory = build_session_factory(engine)
|
||||
|
||||
@contextmanager
|
||||
def db_factory():
|
||||
with session_scope(factory) as session:
|
||||
yield session
|
||||
|
||||
# 3. Init simulation using experiment world config
|
||||
company_id = _init_simulation(db_factory, args, experiment_cfg, horizon_years)
|
||||
|
||||
# 4. Build runtime settings from experiment AgentConfig
|
||||
# Build runtime settings (shared across episodes)
|
||||
agent_cfg = experiment_cfg.agent
|
||||
settings = RuntimeSettings(
|
||||
model=agent_cfg.model,
|
||||
|
|
@ -222,7 +233,7 @@ def run_benchmark(args):
|
|||
)
|
||||
runtime = build_runtime(settings, command_executor=run_command)
|
||||
|
||||
# 5. Build run state
|
||||
# Build run state (persists across episodes)
|
||||
session_id = f"run-{args.seed}-{args.model}"
|
||||
run_state = RunState(
|
||||
session_id=session_id,
|
||||
|
|
@ -231,66 +242,110 @@ def run_benchmark(args):
|
|||
horizon_years=horizon_years,
|
||||
)
|
||||
|
||||
# 6. Set up live dashboard (or not)
|
||||
dashboard = None
|
||||
on_turn_start = None
|
||||
on_turn = None
|
||||
|
||||
if use_live:
|
||||
from .dashboard import BenchmarkDashboard
|
||||
|
||||
dashboard = BenchmarkDashboard(
|
||||
model=args.model,
|
||||
seed=args.seed,
|
||||
config_name=args.config_name,
|
||||
db_factory=db_factory,
|
||||
company_id=company_id,
|
||||
)
|
||||
|
||||
def on_turn_start(turn_num):
|
||||
dashboard.mark_turn_start(turn_num)
|
||||
|
||||
def on_turn(snapshot, rs, commands):
|
||||
dashboard.update(snapshot, rs, commands)
|
||||
|
||||
# 7. Run agent loop
|
||||
loop_cfg = experiment_cfg.loop
|
||||
try:
|
||||
max_episodes = args.max_episodes
|
||||
carried_scratchpad = ""
|
||||
|
||||
for episode in range(1, max_episodes + 1):
|
||||
logger.info("=== Episode %d / %d ===", episode, max_episodes)
|
||||
|
||||
# 1. Build engine for this episode's DB
|
||||
db_url = _build_db_url(args, episode, max_episodes)
|
||||
os.environ["DATABASE_URL"] = db_url
|
||||
engine = build_engine()
|
||||
init_db(engine)
|
||||
factory = build_session_factory(engine)
|
||||
|
||||
@contextmanager
|
||||
def db_factory(_factory=factory):
|
||||
with session_scope(_factory) as session:
|
||||
yield session
|
||||
|
||||
# 2. Init simulation
|
||||
company_id = _init_simulation(db_factory, args, experiment_cfg, horizon_years)
|
||||
|
||||
# 3. Restore scratchpad from previous episode
|
||||
if episode > 1 and carried_scratchpad:
|
||||
_write_scratchpad(db_factory, company_id, carried_scratchpad)
|
||||
logger.info("Restored scratchpad from episode %d (%d chars).", episode - 1, len(carried_scratchpad))
|
||||
|
||||
# 4. Set up live dashboard
|
||||
dashboard = None
|
||||
on_turn_start = None
|
||||
on_turn = None
|
||||
|
||||
if use_live:
|
||||
from .dashboard import BenchmarkDashboard
|
||||
|
||||
dashboard = BenchmarkDashboard(
|
||||
model=args.model,
|
||||
seed=args.seed,
|
||||
config_name=args.config_name,
|
||||
db_factory=db_factory,
|
||||
company_id=company_id,
|
||||
)
|
||||
|
||||
def on_turn_start(turn_num):
|
||||
dashboard.mark_turn_start(turn_num)
|
||||
|
||||
def on_turn(snapshot, rs, commands):
|
||||
dashboard.update(snapshot, rs, commands)
|
||||
|
||||
# 5. Run agent loop for this episode
|
||||
try:
|
||||
if dashboard is not None:
|
||||
dashboard.start()
|
||||
|
||||
final_state = run_agent_loop(
|
||||
runtime=runtime,
|
||||
db_factory=db_factory,
|
||||
company_id=company_id,
|
||||
run_state=run_state,
|
||||
command_executor=run_command,
|
||||
auto_advance_after_turns=loop_cfg.auto_advance_after_turns,
|
||||
max_turns=loop_cfg.max_turns,
|
||||
on_turn_start=on_turn_start,
|
||||
on_turn=on_turn,
|
||||
episode=episode,
|
||||
)
|
||||
finally:
|
||||
if dashboard is not None:
|
||||
dashboard.stop()
|
||||
|
||||
if dashboard is not None:
|
||||
dashboard.start()
|
||||
dashboard.print_final_summary(final_state)
|
||||
|
||||
final_state = run_agent_loop(
|
||||
runtime=runtime,
|
||||
db_factory=db_factory,
|
||||
company_id=company_id,
|
||||
run_state=run_state,
|
||||
command_executor=run_command,
|
||||
auto_advance_after_turns=loop_cfg.auto_advance_after_turns,
|
||||
max_turns=loop_cfg.max_turns,
|
||||
on_turn_start=on_turn_start,
|
||||
on_turn=on_turn,
|
||||
)
|
||||
finally:
|
||||
if dashboard is not None:
|
||||
dashboard.stop()
|
||||
# 6. For multi-episode runs, snapshot this episode's data
|
||||
if max_episodes > 1:
|
||||
run_state.finish_episode()
|
||||
|
||||
# 8. Print final summary
|
||||
if dashboard is not None:
|
||||
dashboard.print_final_summary(final_state)
|
||||
logger.info("Episode %d finished: reason=%s", episode, run_state.terminal_reason)
|
||||
|
||||
# 9. Save full rollout (with transcript) and print summary
|
||||
# 7. If not bankruptcy, or last episode, stop
|
||||
if run_state.terminal_reason != TerminalReason.BANKRUPTCY or episode == max_episodes:
|
||||
break
|
||||
|
||||
# 8. Save scratchpad for next episode, then reset
|
||||
carried_scratchpad = _read_scratchpad(db_factory, company_id)
|
||||
logger.info("Carrying scratchpad to episode %d (%d chars).", episode + 1, len(carried_scratchpad))
|
||||
|
||||
# Clear runtime session (fresh conversation history)
|
||||
runtime.clear_session(session_id)
|
||||
run_state.reset_for_new_episode()
|
||||
engine.dispose()
|
||||
|
||||
# 9. Save full rollout and print summary
|
||||
rollout = final_state.full_rollout()
|
||||
rollout["time_series"] = extract_time_series(db_factory, company_id)
|
||||
summary = final_state.summary()
|
||||
logger.info("Run complete: %s", json.dumps(summary, indent=2))
|
||||
|
||||
# Write full rollout (includes transcript with commands)
|
||||
slug = args.model.replace("/", "_")
|
||||
results_dir = Path("results")
|
||||
results_dir.mkdir(exist_ok=True)
|
||||
results_path = results_dir / f"yc_bench_result_{args.config_name}_{args.seed}_{slug}.json"
|
||||
results_path.write_text(json.dumps(rollout, indent=2))
|
||||
logger.info("Full rollout written to %s (%d turns)", results_path, len(rollout.get("transcript", [])))
|
||||
logger.info("Full rollout written to %s", results_path)
|
||||
|
||||
return 0 if final_state.terminal_reason != "error" else 1
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue