atropos/environments
2025-05-14 13:57:20 -07:00
..
configs hide complicated openai config override behavior somewhere else 2025-05-03 14:18:50 -07:00
dataset_environment run pre-commit on all files 2025-05-09 09:54:20 -05:00
game_environments/gymnasium Updated README 2025-05-14 13:57:20 -07:00
multimodal_dpo run pre-commit on all files 2025-05-09 09:54:20 -05:00
fundamental_prediction_environment.py run pre-commit on all files 2025-05-09 09:54:20 -05:00
gsm8k_server.py first commit 2025-04-29 12:10:10 -07:00
gym_taxi.py add gym taxi env 2025-05-09 19:05:01 -05:00
math_server.py run pre-commit on all files 2025-05-09 09:54:20 -05:00
math_server_zero.py fix olympiadbench due to upstream changes 2025-05-09 09:41:10 -05:00
mcqa_thinking_env.py first commit 2025-04-29 12:10:10 -07:00
README.md first commit 2025-04-29 12:10:10 -07:00
rlaif_server.py first commit 2025-04-29 12:10:10 -07:00
tool_calling_server.py first commit 2025-04-29 12:10:10 -07:00

Environments

This directory contains various environments for training and evaluating language models on different tasks. Each environment implements a specific task with its own input format, reward function, and evaluation metrics.

Available Environments


MCQA Thinking Environment (mcqa_thinking_env.py)

Multiple Choice Question Answering environment that requires models to think through problems systematically.

Input Format:

  • Questions from the MMLU (Massive Multitask Language Understanding) dataset
  • Each item contains:
    • prompt: The question text
    • answer: Index of correct answer
    • ground_truth: Letter (A, B, C, D) of correct answer
    • options: List of possible answers

System Prompt:

You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem.

Reward Function:

  • Score of 1.0 if the model's answer matches the ground truth letter
  • Score of 0.0 if incorrect or invalid response (multiple think tags, malformed thinking sections)
  • Length penalty applied if all responses are correct:
    • No penalty for responses under 50% of max token length
    • Linear penalty scaling from 1.0 down to 0.0 for responses between 50% and 100% of max length
    • Returns None if all scores are identical (no learning signal)

GSM8K Environment (gsm8k_server.py)

Mathematical reasoning environment using the GSM8K dataset.

Input Format:

  • Questions from GSM8K dataset
  • Each item contains:
    • question: The math problem
    • answer: The numerical answer

System Prompt:

You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem.

You are allocated a maximum of 2048 tokens, please strive to use less.

You will then provide your answer like this: \boxed{your answer here}
It is important that you provide your answer in the correct format.
If you do not, you will not receive credit for your answer.
So please end your answer with \boxed{your answer here}

Reward Function:

  • Score of 1.0 if the model's answer matches the ground truth (using LaTeX verification)
  • Score of 0.0 if incorrect or if ground truth is not parseable
  • Length penalty applied if all responses are correct:
    • No penalty for responses under 50% of max token length
    • Linear penalty scaling from 1.0 down to 0.0 for responses between 50% and 100% of max length
    • Returns None if all scores are identical (no learning signal)

Tool Calling Environment (tool_calling_server.py)

Environment for training models to make function calls in a structured format.

Input Format:

  • Conversations from ShareGPT-Hermes function call dataset
  • Each item contains:
    • conversations: List of messages with roles (system, human, gpt)
    • Expected tool calls in JSON format

System Prompt:

You are a deep thinking AI, you may use extremely long chains of thought to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help come to a correct solution prior to answering. You should enclose your thoughts and internal monologue inside <think> </think> tags, and then provide your solution or response to the problem.

Reward Function:

  • Score of 1.0 if all expected tool calls are present and match exactly (including nested JSON fields)
  • Score of 0.0 if any tool calls are missing, incorrect, or malformed
  • Length penalty applied if all responses are correct:
    • No penalty for responses under 50% of max token length
    • Linear penalty scaling from 1.0 down to 0.0 for responses between 50% and 100% of max length
    • Returns None if all scores are identical (no learning signal)

Common Features

All environments share these common features:

  1. Training/Test Split:

    • 98% training, 2% test split
    • Random shuffling with fixed seed (42)
  2. Metrics Tracking:

    • Percent correct buffer
    • Completion lengths
    • Wandb integration for visualization
    • Rollout tracking
  3. Token Management:

    • Maximum token length limits
    • Token length statistics tracking
    • Length penalty for excessive responses
  4. Evaluation:

    • Separate evaluation on test set
    • Comprehensive metrics logging
    • Support for multiple model completions per prompt

Usage

Each environment can be initialized with:

  • config: BaseEnvConfig object
  • server_configs: List of OpenAI API configurations
  • slurm: Boolean for distributed training
  • testing: Boolean for testing mode

The environments follow a common interface with methods for:

  • setup(): Loading and preparing datasets
  • get_next_item(): Retrieving next training item
  • collect_trajectories(): Generating model responses
  • score(): Computing rewards
  • evaluate(): Running evaluation on test set
  • wandb_log(): Logging metrics to Weights & Biases