mirror of https://github.com/NousResearch/atropos.git, synced 2026-05-03 17:53:17 +00:00
fix score
This commit is contained in:
parent 8a0e107806
commit 2ab8905d4f
1 changed file with 13 additions and 5 deletions
@@ -17,6 +17,7 @@ from atroposlib.envs.base import (
     ScoredDataGroup,
 )
 from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
+from pydantic import Field  # Added import for Field
 
 # System prompt can be reused or adapted for instruction following tasks
 system_prompt = (
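
For reference, pydantic's Field attaches a default value and a human-readable description to a model attribute. A minimal sketch, assuming pydantic v2; the Example model below is illustrative and not part of the commit:

from pydantic import BaseModel, Field

class Example(BaseModel):
    ratio: float = Field(0.05, description="Fraction of data held out")

assert Example().ratio == 0.05
assert Example.model_fields["ratio"].description == "Fraction of data held out"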
@@ -27,10 +28,17 @@ system_prompt = (
 )
 
 
+class IFConfig(BaseEnvConfig):
+    dataset_name: str = Field("allenai/RLVR-IFeval", description="Default dataset name")
+    dataset_config_name: Optional[str] = Field(None, description="Dataset config name, if any")
+    test_set_ratio: float = Field(0.05, description="The ratio of the selected dataset for testing")
+
+
 class InstructionFollowingEnv(BaseEnv):
+    env_config_cls = IFConfig  # Added env_config_cls for IFConfig
     def __init__(
         self,
-        config: BaseEnvConfig,
+        config: IFConfig,  # Changed BaseEnvConfig to IFConfig
         server_configs: List[APIServerConfig],
         slurm=True,
         testing=False,
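
The new IFConfig gives the environment its own typed config with dataset defaults, and env_config_cls tells the base environment which config class to use. A minimal sketch of how the Field defaults behave, using a plain BaseModel as a stand-in for BaseEnvConfig so it runs outside the repo:

from typing import Optional
from pydantic import BaseModel, Field

# Stand-in for BaseEnvConfig; illustrative only.
class IFConfig(BaseModel):
    dataset_name: str = Field("allenai/RLVR-IFeval", description="Default dataset name")
    dataset_config_name: Optional[str] = Field(None, description="Dataset config name, if any")
    test_set_ratio: float = Field(0.05, description="The ratio of the selected dataset for testing")

cfg = IFConfig()  # picks up the declared defaults
assert cfg.dataset_name == "allenai/RLVR-IFeval"
assert cfg.dataset_config_name is None
assert IFConfig(test_set_ratio=0.1).test_set_ratio == 0.1  # per-run override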
@@ -42,9 +50,9 @@ class InstructionFollowingEnv(BaseEnv):
         # self.completion_lengths = []  # Kept from SingleToolCallingEnv, assess utility
 
     @classmethod
-    def config_init(self) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
+    def config_init(self) -> Tuple[IFConfig, List[APIServerConfig]]:  # Changed BaseEnvConfig to IFConfig
         # Configuration for the Instruction Following Environment
-        env_config = BaseEnvConfig(
+        env_config = IFConfig(  # Changed BaseEnvConfig to IFConfig
             tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
             group_size=16,
             use_wandb=True,
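
Updating the config_init annotation and constructor keeps the declared return type in step with the object actually built, so env-specific fields such as test_set_ratio are visible to type checkers. One plausible use of that ratio, sketched with a hypothetical split_dataset helper (the commit's actual split logic is not shown in this diff):

def split_dataset(rows: list, test_set_ratio: float = 0.05) -> tuple[list, list]:
    """Hold out the trailing test_set_ratio fraction of rows for evaluation."""
    cut = int(len(rows) * (1 - test_set_ratio))
    return rows[:cut], rows[cut:]

train, test = split_dataset(list(range(1000)))
assert (len(train), len(test)) == (950, 50)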
@@ -414,8 +422,8 @@ class InstructionFollowingEnv(BaseEnv):
         # Get score (1.0 for correct, 0.0 for incorrect from verifier)
         score_value = await self._get_score_from_verifier(model_response_text, func_name, args_for_verifier)
 
-        # Map to reward: 1.0 for correct, -1.0 for incorrect
-        reward = 1.0 if score_value == 1.0 else -1.0
+        # Map to reward: 1.0 for correct, 0 for incorrect
+        reward = 1.0 if score_value == 1.0 else 0
 
         # Tokenize the conversation for PPO training
         # Ensure full_trajectory_messages is a list of dicts
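
This hunk is the change the "fix score" message refers to: an incorrect completion now contributes a reward of 0 rather than -1.0, so a fully incorrect group averages 0 instead of -1. A standalone sketch of the mapping, assuming the verifier returns exactly 1.0 or 0.0 as the diff's comments state:

def score_to_reward(score_value: float) -> float:
    # After this commit: 1.0 for correct, 0 for incorrect (previously -1.0).
    return 1.0 if score_value == 1.0 else 0.0

assert score_to_reward(1.0) == 1.0
assert score_to_reward(0.0) == 0.0
# Group mean under the new mapping: an all-incorrect group of 16 scores 0.0.
assert sum(score_to_reward(s) for s in [0.0] * 16) / 16 == 0.0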