Merge pull request #11 from collinear-ai/feat/multi-episode

Add multi-episode setting with scratchpad carryover
This commit is contained in:
Adit Jain 2026-03-13 18:21:37 -07:00 committed by GitHub
commit d976b9cbb4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 490 additions and 82 deletions

View file

@ -13,6 +13,7 @@ class RunArgs:
start_date: str
config_name: str = "default"
no_live: bool = False
max_episodes: int = 1
def build_parser():
parser = argparse.ArgumentParser(
@ -32,6 +33,10 @@ def build_parser():
"--no-live", action="store_true", default=False,
help="Disable the live terminal dashboard (show raw log output instead)",
)
parser.add_argument(
"--max-episodes", type=int, default=1,
help="Max episodes (restarts after bankruptcy with scratchpad carried over). Default: 1",
)
return parser
def parse_run_args(argv):
@ -46,10 +51,13 @@ def parse_run_args(argv):
start_date=ns.start_date,
config_name=ns.config_name,
no_live=ns.no_live,
max_episodes=ns.max_episodes,
)
def _validate(ns, parser):
    """Check parsed CLI values and report violations through ``parser.error``.

    Args:
        ns: Parsed namespace exposing ``horizon_years`` (may be ``None``) and
            ``max_episodes``.
        parser: Object with an ``error(message)`` method; presumably the
            ``argparse.ArgumentParser`` that produced ``ns`` (caller not
            visible here).
    """
    horizon = ns.horizon_years
    # A missing horizon is allowed; only an explicit non-positive value is rejected.
    horizon_is_bad = horizon is not None and horizon <= 0
    if horizon_is_bad:
        parser.error("--horizon-years must be int > 0")

    episodes = ns.max_episodes
    if episodes < 1:
        parser.error("--max-episodes must be int >= 1")
__all__ = ["RunArgs", "build_parser", "parse_run_args"]

View file

@ -11,7 +11,7 @@ from pathlib import Path
from ..agent.commands.executor import run_command
from ..agent.loop import run_agent_loop
from ..agent.run_state import RunState
from ..agent.run_state import RunState, TerminalReason
from ..agent.runtime.factory import build_runtime
from ..agent.runtime.schemas import RuntimeSettings
from ..db.session import build_engine, build_session_factory, session_scope, init_db
@ -154,6 +154,37 @@ def _redirect_all_logging_to_file(log_file: Path) -> None:
lg.addHandler(file_handler)
def _build_db_url(args, episode: int, max_episodes: int) -> str:
    """Return the SQLite URL for one episode's database file under ``db/``.

    The file name is ``<config_name>_<seed>_<model slug>``, where the slug is
    ``args.model`` with ``/`` replaced by ``_``. When ``max_episodes`` is
    greater than 1 an ``.ep<episode>`` component is inserted before ``.db``
    so each episode gets its own file.

    Side effect: creates the relative ``db`` directory if it does not exist.
    """
    model_slug = args.model.replace("/", "_")

    target_dir = Path("db")
    target_dir.mkdir(exist_ok=True)

    stem = f"{args.config_name}_{args.seed}_{model_slug}"

    # Single-episode runs keep the plain file name.
    if max_episodes <= 1:
        return f"sqlite:///{target_dir}/{stem}.db"
    return f"sqlite:///{target_dir}/{stem}.ep{episode}.db"
def _read_scratchpad(db_factory, company_id) -> str:
    """Read the scratchpad content for ``company_id`` from the current DB.

    Args:
        db_factory: Zero-argument callable returning a context manager that
            yields a session object supporting ``query(...)``.
        company_id: Value matched against ``Scratchpad.company_id``.

    Returns:
        The stored content, or ``""`` when there is no scratchpad row or the
        row's content is empty/``None``. Always a ``str``, so callers can
        safely take ``len()`` of the result.
    """
    from ..db.models.scratchpad import Scratchpad

    with db_factory() as db:
        row = db.query(Scratchpad).filter(Scratchpad.company_id == company_id).first()
        if row is None:
            return ""
        # Read the attribute while the session is still open, and coerce a
        # missing (None) content to "" to honour the declared return type.
        return row.content or ""
def _write_scratchpad(db_factory, company_id, content: str) -> None:
    """Store ``content`` as the scratchpad for ``company_id`` (insert or update).

    Args:
        db_factory: Zero-argument callable returning a context manager that
            yields a session object supporting ``query``, ``add`` and ``flush``.
        company_id: Value matched against ``Scratchpad.company_id``.
        content: Text to store in the row's ``content`` attribute.
    """
    from ..db.models.scratchpad import Scratchpad

    with db_factory() as session:
        existing = (
            session.query(Scratchpad)
            .filter(Scratchpad.company_id == company_id)
            .first()
        )
        if existing is not None:
            # A row already exists for this company: overwrite its content.
            existing.content = content
        else:
            session.add(Scratchpad(company_id=company_id, content=content))
        session.flush()
def run_benchmark(args):
"""Run a full benchmark: migrate, seed, loop until terminal."""
logging.basicConfig(
@ -184,31 +215,11 @@ def run_benchmark(args):
_redirect_all_logging_to_file(log_file)
logger.info(
"YC-Bench starting: experiment=%s model=%s seed=%d horizon=%dy",
experiment_cfg.name, args.model, args.seed, horizon_years,
"YC-Bench starting: experiment=%s model=%s seed=%d horizon=%dy max_episodes=%d",
experiment_cfg.name, args.model, args.seed, horizon_years, args.max_episodes,
)
# 1. Build engine and create all tables
# If DATABASE_URL is not explicitly set, default to db/<config>_<seed>_<slug>.db
if not os.environ.get("DATABASE_URL"):
slug = args.model.replace("/", "_")
db_dir = Path("db")
db_dir.mkdir(exist_ok=True)
os.environ["DATABASE_URL"] = f"sqlite:///{db_dir}/{args.config_name}_{args.seed}_{slug}.db"
engine = build_engine()
init_db(engine)
factory = build_session_factory(engine)
@contextmanager
def db_factory():
with session_scope(factory) as session:
yield session
# 3. Init simulation using experiment world config
company_id = _init_simulation(db_factory, args, experiment_cfg, horizon_years)
# 4. Build runtime settings from experiment AgentConfig
# Build runtime settings (shared across episodes)
agent_cfg = experiment_cfg.agent
settings = RuntimeSettings(
model=agent_cfg.model,
@ -222,7 +233,7 @@ def run_benchmark(args):
)
runtime = build_runtime(settings, command_executor=run_command)
# 5. Build run state
# Build run state (persists across episodes)
session_id = f"run-{args.seed}-{args.model}"
run_state = RunState(
session_id=session_id,
@ -231,66 +242,110 @@ def run_benchmark(args):
horizon_years=horizon_years,
)
# 6. Set up live dashboard (or not)
dashboard = None
on_turn_start = None
on_turn = None
if use_live:
from .dashboard import BenchmarkDashboard
dashboard = BenchmarkDashboard(
model=args.model,
seed=args.seed,
config_name=args.config_name,
db_factory=db_factory,
company_id=company_id,
)
def on_turn_start(turn_num):
dashboard.mark_turn_start(turn_num)
def on_turn(snapshot, rs, commands):
dashboard.update(snapshot, rs, commands)
# 7. Run agent loop
loop_cfg = experiment_cfg.loop
try:
max_episodes = args.max_episodes
carried_scratchpad = ""
for episode in range(1, max_episodes + 1):
logger.info("=== Episode %d / %d ===", episode, max_episodes)
# 1. Build engine for this episode's DB
db_url = _build_db_url(args, episode, max_episodes)
os.environ["DATABASE_URL"] = db_url
engine = build_engine()
init_db(engine)
factory = build_session_factory(engine)
@contextmanager
def db_factory(_factory=factory):
with session_scope(_factory) as session:
yield session
# 2. Init simulation
company_id = _init_simulation(db_factory, args, experiment_cfg, horizon_years)
# 3. Restore scratchpad from previous episode
if episode > 1 and carried_scratchpad:
_write_scratchpad(db_factory, company_id, carried_scratchpad)
logger.info("Restored scratchpad from episode %d (%d chars).", episode - 1, len(carried_scratchpad))
# 4. Set up live dashboard
dashboard = None
on_turn_start = None
on_turn = None
if use_live:
from .dashboard import BenchmarkDashboard
dashboard = BenchmarkDashboard(
model=args.model,
seed=args.seed,
config_name=args.config_name,
db_factory=db_factory,
company_id=company_id,
)
def on_turn_start(turn_num):
dashboard.mark_turn_start(turn_num)
def on_turn(snapshot, rs, commands):
dashboard.update(snapshot, rs, commands)
# 5. Run agent loop for this episode
try:
if dashboard is not None:
dashboard.start()
final_state = run_agent_loop(
runtime=runtime,
db_factory=db_factory,
company_id=company_id,
run_state=run_state,
command_executor=run_command,
auto_advance_after_turns=loop_cfg.auto_advance_after_turns,
max_turns=loop_cfg.max_turns,
on_turn_start=on_turn_start,
on_turn=on_turn,
episode=episode,
)
finally:
if dashboard is not None:
dashboard.stop()
if dashboard is not None:
dashboard.start()
dashboard.print_final_summary(final_state)
final_state = run_agent_loop(
runtime=runtime,
db_factory=db_factory,
company_id=company_id,
run_state=run_state,
command_executor=run_command,
auto_advance_after_turns=loop_cfg.auto_advance_after_turns,
max_turns=loop_cfg.max_turns,
on_turn_start=on_turn_start,
on_turn=on_turn,
)
finally:
if dashboard is not None:
dashboard.stop()
# 6. For multi-episode runs, snapshot this episode's data
if max_episodes > 1:
run_state.finish_episode()
# 8. Print final summary
if dashboard is not None:
dashboard.print_final_summary(final_state)
logger.info("Episode %d finished: reason=%s", episode, run_state.terminal_reason)
# 9. Save full rollout (with transcript) and print summary
# 7. If not bankruptcy, or last episode, stop
if run_state.terminal_reason != TerminalReason.BANKRUPTCY or episode == max_episodes:
break
# 8. Save scratchpad for next episode, then reset
carried_scratchpad = _read_scratchpad(db_factory, company_id)
logger.info("Carrying scratchpad to episode %d (%d chars).", episode + 1, len(carried_scratchpad))
# Clear runtime session (fresh conversation history)
runtime.clear_session(session_id)
run_state.reset_for_new_episode()
engine.dispose()
# 9. Save full rollout and print summary
rollout = final_state.full_rollout()
rollout["time_series"] = extract_time_series(db_factory, company_id)
summary = final_state.summary()
logger.info("Run complete: %s", json.dumps(summary, indent=2))
# Write full rollout (includes transcript with commands)
slug = args.model.replace("/", "_")
results_dir = Path("results")
results_dir.mkdir(exist_ok=True)
results_path = results_dir / f"yc_bench_result_{args.config_name}_{args.seed}_{slug}.json"
results_path.write_text(json.dumps(rollout, indent=2))
logger.info("Full rollout written to %s (%d turns)", results_path, len(rollout.get("transcript", [])))
logger.info("Full rollout written to %s", results_path)
return 0 if final_state.terminal_reason != "error" else 1