fix score

This commit is contained in:
teknium1 2025-05-14 19:35:43 -07:00
parent 8a0e107806
commit 2ab8905d4f

View file

@ -17,6 +17,7 @@ from atroposlib.envs.base import (
ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from pydantic import Field # Added import for Field
# System prompt can be reused or adapted for instruction following tasks
system_prompt = (
@ -27,10 +28,17 @@ system_prompt = (
)
# Configuration for the instruction-following environment: extends BaseEnvConfig
# with the dataset selection fields and the test-split ratio declared below.
class IFConfig(BaseEnvConfig):
# Name of the dataset to use; defaults to "allenai/RLVR-IFeval".
dataset_name: str = Field("allenai/RLVR-IFeval", description="Default dataset name")
# Optional dataset config name; defaults to None when the dataset has none.
dataset_config_name: Optional[str] = Field(None, description="Dataset config name, if any")
# Ratio of the selected dataset set aside for testing; defaults to 0.05.
test_set_ratio: float = Field(0.05, description="The ratio of the selected dataset for testing")
class InstructionFollowingEnv(BaseEnv):
env_config_cls = IFConfig # Added env_config_cls for IFConfig
def __init__(
self,
config: BaseEnvConfig,
config: IFConfig, # Changed BaseEnvConfig to IFConfig
server_configs: List[APIServerConfig],
slurm=True,
testing=False,
@ -42,9 +50,9 @@ class InstructionFollowingEnv(BaseEnv):
# self.completion_lengths = [] # Kept from SingleToolCallingEnv, assess utility
@classmethod
def config_init(self) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
def config_init(self) -> Tuple[IFConfig, List[APIServerConfig]]: # Changed BaseEnvConfig to IFConfig
# Configuration for the Instruction Following Environment
env_config = BaseEnvConfig(
env_config = IFConfig( # Changed BaseEnvConfig to IFConfig
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
group_size=16,
use_wandb=True,
@ -414,8 +422,8 @@ class InstructionFollowingEnv(BaseEnv):
# Get score (1.0 for correct, 0.0 for incorrect from verifier)
score_value = await self._get_score_from_verifier(model_response_text, func_name, args_for_verifier)
# Map to reward: 1.0 for correct, -1.0 for incorrect
reward = 1.0 if score_value == 1.0 else -1.0
# Map to reward: 1.0 for correct, 0 for incorrect
reward = 1.0 if score_value == 1.0 else 0
# Tokenize the conversation for PPO training
# Ensure full_trajectory_messages is a list of dicts