Merge commit '71e7a5ca27' into add-support-for-custom-api-servers

This commit is contained in:
dmahan93 2025-05-12 18:40:35 -05:00
commit 96be544228
45 changed files with 1605 additions and 494 deletions

View file

@ -1,9 +1,7 @@
import base64
import io
import os
import random
import re
import sys
import traceback
from typing import List, Optional, Tuple
@ -73,13 +71,8 @@ class OcrVqaEnv(BaseEnv):
history: GameHistory = (user_hist, assistant_hist)
to_score.append((history, gold, base64_image))
return to_score, to_backlog
async def postprocess_histories(
self, trajectories: List[GameHistory]
) -> ScoredDataGroup:
# No additional post-processing needed
pass
to_postprocess = await self.score(to_score)
return to_postprocess, to_backlog
async def evaluate(self, *args, **kwargs):
# No custom evaluation
@ -172,29 +165,25 @@ class OcrVqaEnv(BaseEnv):
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
if not os.environ.get("OPENAI_API_KEY"):
print("ERROR: OPENAI_API_KEY environment variable is not set!")
sys.exit(1)
config = BaseEnvConfig(
wandb_name="ocr_vqa",
tokenizer_name="gpt2",
group_size=2,
use_wandb=False,
tokenizer_name="Qwen/Qwen2-VL-2B-Instruct",
group_size=8,
use_wandb=True,
max_num_workers=2,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=1,
steps_per_eval=10,
ensure_scores_are_not_same=False,
batch_size=12,
steps_per_eval=100,
max_token_length=2048,
)
server_configs = [
APIServerConfig(
model_name="gpt-4o",
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
num_requests_for_eval=1,
model_name="Qwen/Qwen2-VL-2B-Instruct",
base_url="http://localhost:9001/v1",
api_key="x",
num_requests_for_eval=256,
),
]