# atropos/environments/community/protein_design/protein_env.py
import json
import logging
import os
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, TypedDict
import wandb
import yaml
from datasets import Dataset, load_dataset
from dotenv import load_dotenv
from pydantic import Field
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
Item,
ScoredDataGroup,
)
from atroposlib.type_definitions import Message
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
from .prompts import SYSTEM_PROMPT, construct_user_prompt
from .tool_definitions import ALL_TOOLS_LIST
from .tool_executor import ToolExecutor
logger = logging.getLogger(__name__)
load_dotenv()
def load_target_binder_pairs(
dataset_name: str, target_col: str, binder_col: str, split: str = "train"
) -> Dataset:
"""
Loads and transforms a Hugging Face dataset to contain only 'target' and 'binder' columns.
Args:
dataset_name (str): Hugging Face dataset identifier.
target_col (str): Name of the column containing target protein sequences.
binder_col (str): Name of the column containing binder sequences.
split (str): Dataset split to load.
Returns:
Dataset: Hugging Face Dataset object with columns ['target', 'binder'].
"""
ds = load_dataset(dataset_name, split=split)
logger.info(f"Loaded dataset with columns: {ds.column_names}")
actual_target_col = target_col
actual_binder_col = binder_col
try:
ds = ds.rename_columns(
{actual_target_col: "target", actual_binder_col: "binder"}
)
ds = ds.remove_columns(
[col for col in ds.column_names if col not in {"target", "binder"}]
)
except ValueError as e:
logger.error(f"Error renaming columns: {e}")
logger.error(f"Available columns: {ds.column_names}")
if (
actual_target_col in ds.column_names
and actual_binder_col in ds.column_names
):
ds = ds.select_columns([actual_target_col, actual_binder_col])
ds = ds.rename_columns(
{actual_target_col: "target", actual_binder_col: "binder"}
)
else:
logger.error(
f"Could not find expected columns in dataset. Available columns: {ds.column_names}"
)
raise ValueError(
f"Dataset {dataset_name} doesn't have the expected columns. Please check your dataset configuration."
)
return ds
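# Usage sketch (dataset and column names are the BinderBenchConfig defaults; illustrative only):
#     pairs = load_target_binder_pairs(
#         "ronig/protein_binding_sequences", target_col="receptor", binder_col="peptide"
#     )
#     row = pairs[0]  # {"target": "<receptor sequence>", "binder": "<peptide sequence>"}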
class BinderRow(TypedDict):
target: str
binder: str
class BinderBenchConfig(BaseEnvConfig):
nim_api_key: Optional[str] = Field(None, description="NVIDIA NIM API key")
nim_api_base_url: str = Field(
"https://health.api.nvidia.com/v1", description="NIM API base URL"
)
api_timeout: int = Field(1800, description="Timeout for NIM API calls")
polling_interval: int = Field(30, description="Polling interval for NIM jobs")
output_dir: str = Field(
default=str(Path(__file__).parent / "outputs"),
description="Directory to save PDBs, etc.",
)
debug_protein_design_calls: bool = Field(
False,
description="Enable debug mode for NIM protein API calls, returning mock data.",
)
max_retries_per_internal_step: int = Field(
100,
description="Max retries for a failed tool call within a workflow step (0 means no retries).",
)
dataset_name: str = Field(
"ronig/protein_binding_sequences", description="Dataset for target sequences"
)
target_col: str = Field(
"receptor", description="Target column name (actual column in the dataset)"
)
binder_col: str = Field(
"peptide", description="Binder column name (actual column in the dataset)"
)
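# Construction sketch (fields as declared above; values are illustrative):
#     cfg = BinderBenchConfig(
#         nim_api_key=None,                  # falls back to NVIDIA_NIM_API_KEY in setup()
#         debug_protein_design_calls=True,   # mock the NIM protein design calls
#     )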
class BinderBenchEnv(BaseEnv):
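"""
Multi-turn tool-use environment for protein binder design.

Each workflow walks the LLM through four internal steps:
  0. predict_target_structure_alphafold2           - fold the target sequence
  1. design_binder_backbone_rfdiffusion            - generate a binder backbone
  2. design_binder_sequence_proteinmpnn            - assign a sequence to the backbone
  3. evaluate_binder_complex_alphafold2_multimer   - score the complex via pLDDT

Steps 0-2 earn a small format reward (0.2 for a correct, successful tool call);
step 3 is rewarded based on the AF2-Multimer pLDDT (see _score_trajectory).
"""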
name = "binderbench"
env_config_cls = BinderBenchConfig
def __init__(
self,
config: BinderBenchConfig,
server_configs: List[APIServerConfig],
slurm=False,
testing=False,
):
super().__init__(config, server_configs, slurm, testing)
self.config: BinderBenchConfig
self.process_mode = False
self.tools = ALL_TOOLS_LIST
self.output_dir = Path(self.config.output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
self.episodes_state = {}
self.completed_episode_metrics: List[Dict] = []
self.rollouts_for_wandb = []
self.tool_executor = ToolExecutor(
nim_api_key=self.config.nim_api_key,
api_timeout=self.config.api_timeout,
polling_interval=self.config.polling_interval,
output_dir=self.output_dir,
debug_protein_design_calls=self.config.debug_protein_design_calls,
)
async def _execute_tool(
self, tool_name: str, args: Dict, workflow_state: Dict
) -> Dict:
"""Delegates tool execution and then updates workflow_state based on the result."""
execution_result_package = await self.tool_executor.dispatch_tool_call(
tool_name, args, workflow_state
)
tool_output = execution_result_package.get("tool_output", {})
state_updates = execution_result_package.get("state_updates", {})
if state_updates:
workflow_state.update(state_updates)
logger.debug(
f"Workflow {workflow_state['item_id']}: State updated with keys: {list(state_updates.keys())}"
)
return tool_output
@classmethod
def config_init(cls) -> Tuple[BinderBenchConfig, List[APIServerConfig]]:
default_yaml_path = (
Path(__file__).parent / "configs" / "binderbench_default.yaml"
)
yaml_config_values = {}
if default_yaml_path.exists():
with open(default_yaml_path, "r") as f:
yaml_config_values = yaml.safe_load(f) or {}
env_config = BinderBenchConfig(
use_wandb=True,
wandb_name=cls.name,
nim_api_key=os.environ.get("NVIDIA_NIM_API_KEY"),
debug_protein_design_calls=yaml_config_values.get(
"debug_protein_design_calls",
bool(os.environ.get("DEBUG_PROTEIN_DESIGN_CALLS", False)),
),
)
llm_api_key = os.environ.get("OPENAI_API_KEY")
llm_base_url = os.environ.get("OPENAI_API_BASE")
server_configs = [
APIServerConfig(
model_name=os.environ.get("DEFAULT_LLM_MODEL", "gpt-4-turbo"),
api_key=llm_api_key,
base_url=llm_base_url,
)
]
return env_config, server_configs
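# Environment variables consumed above (load_dotenv() lets a local .env supply them):
#     NVIDIA_NIM_API_KEY         - NIM API key for the protein design tools
#     DEBUG_PROTEIN_DESIGN_CALLS - fallback for debug_protein_design_calls
#     OPENAI_API_KEY / OPENAI_API_BASE - credentials and endpoint for the LLM server
#     DEFAULT_LLM_MODEL          - LLM model name (defaults to "gpt-4-turbo")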
async def setup(self):
self.iter = 0
self.train = load_target_binder_pairs(
dataset_name=self.config.dataset_name,
target_col=self.config.target_col,
binder_col=self.config.binder_col,
)
logger.info(f"Loaded {len(self.train)} target-binder pairs for {self.name}.")
if not self.config.nim_api_key:
self.config.nim_api_key = os.environ.get("NVIDIA_NIM_API_KEY")
if not self.config.nim_api_key:
logger.warning(
"NVIDIA NIM API key not set. Protein design functions may not work properly."
)
def _initialize_workflow_state(
self, item_id: str, target_sequence: str, ground_truth_binder: Optional[str]
) -> Dict:
"""Initializes or resets the state for a new workflow."""
return {
"item_id": item_id,
"current_internal_step": 0,
"target_sequence": target_sequence,
"ground_truth_binder_sequence": ground_truth_binder,
"target_pdb_content": None,
"target_chain_details": None,
"binder_backbone_pdb_content": None,
"designed_binder_sequence": None,
"complex_pdb_content_path": None,
"af2_multimer_plddt": 0.0,
"target_structure_predicted": False,
"binder_backbone_designed": False,
"binder_sequence_designed": False,
"complex_evaluated": False,
"workflow_complete_flag": False,
"last_tool_success": True,
"cumulative_reward": 0.0,
"turn_messages_history": [],
"retry_count_this_internal_step": 0,
"previous_tool_error_message": None,
}
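# The boolean flags above mirror the four internal workflow steps, while
# retry_count_this_internal_step and previous_tool_error_message drive the
# retry logic in collect_trajectories.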
async def get_next_item(self) -> Item:
"""
Provides the initial information for a new protein design workflow.
Returns an Item tuple: (item_id, initial_target_sequence_info)
"""
raw_item: BinderRow = self.train[self.iter % len(self.train)]
self.iter += 1
item_id = str(uuid.uuid4())
target_sequence = raw_item["target"]
ground_truth_binder = raw_item.get("binder")
self.episodes_state[item_id] = self._initialize_workflow_state(
item_id, target_sequence, ground_truth_binder
)
return item_id
def reset_state(self, item_id: str) -> dict:
"""Retrieves the workflow state for the given item_id."""
if item_id in self.episodes_state:
return self.episodes_state[item_id]
else:
logger.error(
f"No state found for item_id {item_id}. Creating a default state."
)
return self._initialize_workflow_state(item_id, "", None)
async def collect_trajectories(
self, item_id: str
) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
workflow_state = self.episodes_state.get(item_id)
if not workflow_state:
logger.error(f"Workflow state for item_id {item_id} not found. Skipping.")
return None, []
if workflow_state.get("workflow_complete_flag"):
logger.info(f"Workflow for {item_id} already marked complete. Skipping.")
return None, []
is_processing_mode = getattr(self, "process_mode", False) # Check the flag
if is_processing_mode:
all_turns_data_for_jsonl = []
MAX_INTERNAL_STEPS = 4
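# Internal steps: 0 = target structure (AlphaFold2), 1 = binder backbone (RFdiffusion),
# 2 = binder sequence (ProteinMPNN), 3 = complex evaluation (AlphaFold2-Multimer).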
while workflow_state[
"current_internal_step"
] < MAX_INTERNAL_STEPS and not workflow_state.get("workflow_complete_flag"):
current_turn_messages: List[Message] = []
user_prompt_str = construct_user_prompt(workflow_state)
current_turn_messages.append(
Message(role="system", content=SYSTEM_PROMPT)
)
current_turn_messages.append(
Message(role="user", content=user_prompt_str)
)
llm_response = await self.server.chat_completion(
messages=current_turn_messages,
tools=self.tools,
tool_choice="auto",
n=1,
max_tokens=self.config.max_token_length,
temperature=0.5,
)
assistant_message_obj = llm_response.choices[0].message
assistant_content = assistant_message_obj.content or ""
assistant_tool_calls = []
if (
hasattr(assistant_message_obj, "tool_calls")
and assistant_message_obj.tool_calls
):
assistant_tool_calls = [
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in assistant_message_obj.tool_calls
]
current_turn_messages.append(
Message(
role="assistant",
content=assistant_content,
tool_calls=(
assistant_tool_calls if assistant_tool_calls else None
),
)
)
tool_error_for_retry_prompt = None
if assistant_tool_calls:
tool_call_request = assistant_tool_calls[0]
tool_name = tool_call_request["function"]["name"]
try:
tool_args = json.loads(
tool_call_request["function"]["arguments"]
)
tool_result = await self._execute_tool(
tool_name, tool_args, workflow_state
)
current_turn_messages.append(
Message(
role="tool",
tool_call_id=tool_call_request["id"],
name=tool_name,
content=json.dumps(tool_result),
)
)
workflow_state["last_tool_success"] = tool_result.get(
"success", False
)
if not workflow_state["last_tool_success"]:
tool_error_for_retry_prompt = tool_result.get(
"error", "Tool execution failed."
)
except Exception as e:
error_msg = f"Error processing tool {tool_name}: {str(e)}"
current_turn_messages.append(
Message(
role="tool",
tool_call_id=tool_call_request["id"],
name=tool_name,
content=error_msg,
)
)
workflow_state["last_tool_success"] = False
tool_error_for_retry_prompt = error_msg
else:
workflow_state["last_tool_success"] = False
expected_tool_name = {
0: "AF2",
1: "RFD",
2: "PMPNN",
3: "AF2M",
}.get(workflow_state["current_internal_step"], "a tool")
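# Short aliases for the step tools (presumably AlphaFold2, RFdiffusion,
# ProteinMPNN and AF2-Multimer; cf. the full tool names in _score_trajectory).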
tool_error_for_retry_prompt = (
f"No tool was called, but {expected_tool_name} was expected."
)
workflow_state["previous_tool_error_message"] = (
tool_error_for_retry_prompt
)
turn_score_details = self._score_trajectory(
current_turn_messages, workflow_state
)
current_turn_reward = turn_score_details.get("overall_reward", 0.0)
workflow_state["cumulative_reward"] += current_turn_reward
tokenization_result = tokenize_for_trainer(
self.tokenizer, current_turn_messages, include_messages=False
)
all_turns_data_for_jsonl.append(
{
"tokens_this_turn": tokenization_result["tokens"],
"masks_this_turn": tokenization_result["masks"],
"score_this_turn": current_turn_reward,
"messages_this_turn": current_turn_messages.copy(),
"overrides_this_turn": turn_score_details.copy(),
}
)
if workflow_state["last_tool_success"]:
workflow_state["current_internal_step"] += 1
workflow_state["retry_count_this_internal_step"] = 0
workflow_state["previous_tool_error_message"] = None
else:
if workflow_state["current_internal_step"] <= 3:
workflow_state["retry_count_this_internal_step"] += 1
if (
workflow_state["retry_count_this_internal_step"]
> self.config.max_retries_per_internal_step
):
logger.warning(
f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: "
f"Max retries ({self.config.max_retries_per_internal_step}) reached. "
f"Terminating workflow for this item."
)
workflow_state["workflow_complete_flag"] = True
break
else:
logger.info(
f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: "
f"Failed, attempt {workflow_state['retry_count_this_internal_step']}. "
f"Retrying same step."
)
else:
logger.warning(
f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: "
f"Failure at non-retryable step. Terminating workflow."
)
workflow_state["workflow_complete_flag"] = True
break
if workflow_state["current_internal_step"] >= MAX_INTERNAL_STEPS:
workflow_state["workflow_complete_flag"] = True
logger.info(
f"Workflow {item_id}: All internal steps completed successfully."
)
if not all_turns_data_for_jsonl:
logger.warning(
f"Workflow {item_id} in process mode: No turn data collected."
)
return None, []
html_compatible_messages: List[str] = []
html_compatible_scores: List[float] = []
overrides_for_jsonl: List[Dict[str, Any]] = []
for turn_idx, turn_data in enumerate(all_turns_data_for_jsonl):
turn_str_parts = [f"--- Workflow {item_id} - Turn {turn_idx + 1} ---"]
if turn_data.get("messages_this_turn"):
for msg_obj in turn_data["messages_this_turn"]:
content_str = str(msg_obj.get("content", "[No Content]"))
if msg_obj.get("tool_calls"):
try:
tool_calls_str = json.dumps(
msg_obj.get("tool_calls"), indent=2
)
content_str += f"\nTool Calls:\n{tool_calls_str}"
except TypeError: # Handle non-serializable content if any
content_str += (
"\nTool Calls: [Error serializing tool_calls]"
)
turn_str_parts.append(
f"**{msg_obj.get('role', 'unknown').upper()}**: {content_str}"
)
else:
turn_str_parts.append("No messages recorded for this turn.")
html_compatible_messages.append("\n\n".join(turn_str_parts))
turn_score = turn_data.get("overrides_this_turn", {}).get(
"overall_reward", 0.0
)
html_compatible_scores.append(turn_score)
overrides_for_jsonl.append(turn_data.get("overrides_this_turn", {}))
final_workflow_reward = workflow_state.get("cumulative_reward", 0.0)
if workflow_state.get("complex_evaluated") and workflow_state.get(
"last_tool_success"
):
final_workflow_reward = (
all_turns_data_for_jsonl[-1]
.get("overrides_this_turn", {})
.get("overall_reward", 0.0)
)
all_tokens_per_turn = [
turn_data["tokens_this_turn"]
for turn_data in all_turns_data_for_jsonl
if turn_data.get("tokens_this_turn")
]
all_masks_per_turn = [
turn_data["masks_this_turn"]
for turn_data in all_turns_data_for_jsonl
if turn_data.get("masks_this_turn")
]
if len(all_tokens_per_turn) != len(html_compatible_messages):
logger.error(
f"CRITICAL: Mismatch between tokenized turns ({len(all_tokens_per_turn)}) "
f"and HTML messages ({len(html_compatible_messages)}). JSONL will be problematic."
)
if all_turns_data_for_jsonl and all_tokens_per_turn:
last_tokens = all_tokens_per_turn[-1]
last_masks = all_masks_per_turn[-1]
all_tokens_per_turn = [last_tokens] * len(html_compatible_messages)
all_masks_per_turn = [last_masks] * len(html_compatible_messages)
else:
all_tokens_per_turn = [[] for _ in html_compatible_messages]
all_masks_per_turn = [[] for _ in html_compatible_messages]
process_mode_scored_data = ScoredDataGroup(
tokens=all_tokens_per_turn,
masks=all_masks_per_turn,
messages=html_compatible_messages,
scores=html_compatible_scores,
overrides=overrides_for_jsonl,
group_overrides={
"group_size": len(html_compatible_messages),
"item_id": item_id,
"is_process_mode_full_workflow": True,
"final_score_for_workflow": final_workflow_reward,
"target_sequence": workflow_state.get("target_sequence", "N/A"),
"designed_binder_sequence": workflow_state.get(
"designed_binder_sequence", "N/A"
),
"final_plddt": workflow_state.get("af2_multimer_plddt", 0.0),
},
)
await self.add_rollouts_for_wandb(data_for_log=workflow_state.copy())
self.completed_episode_metrics.append(workflow_state.copy())
if item_id in self.episodes_state:
del self.episodes_state[item_id]
return process_mode_scored_data, []
else:
current_turn_messages_serve: List[Message] = []
user_prompt_str_serve = construct_user_prompt(workflow_state)
current_turn_messages_serve.append(
Message(role="system", content=SYSTEM_PROMPT)
)
current_turn_messages_serve.append(
Message(role="user", content=user_prompt_str_serve)
)
llm_response_serve = await self.server.chat_completion(
messages=current_turn_messages_serve,
tools=self.tools,
tool_choice="auto",
n=1,
max_tokens=self.config.max_token_length,
temperature=0.5,
)
assistant_message_obj_serve = llm_response_serve.choices[0].message
assistant_content_serve = assistant_message_obj_serve.content or ""
assistant_tool_calls_serve = []
if (
hasattr(assistant_message_obj_serve, "tool_calls")
and assistant_message_obj_serve.tool_calls
):
assistant_tool_calls_serve = [
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in assistant_message_obj_serve.tool_calls
]
current_turn_messages_serve.append(
Message(
role="assistant",
content=assistant_content_serve,
tool_calls=(
assistant_tool_calls_serve
if assistant_tool_calls_serve
else None
),
)
)
tool_error_for_retry_prompt_serve = None
if assistant_tool_calls_serve:
tool_call_request_serve = assistant_tool_calls_serve[0]
tool_name_serve = tool_call_request_serve["function"]["name"]
try:
tool_args_json_str = tool_call_request_serve["function"][
"arguments"
]
tool_args_serve = json.loads(tool_args_json_str)
tool_result_serve = await self._execute_tool(
tool_name_serve, tool_args_serve, workflow_state
)
current_turn_messages_serve.append(
Message(
role="tool",
tool_call_id=tool_call_request_serve["id"],
name=tool_name_serve,
content=json.dumps(tool_result_serve),
)
)
workflow_state["last_tool_success"] = tool_result_serve.get(
"success", False
)
if not workflow_state["last_tool_success"]:
tool_error_for_retry_prompt_serve = tool_result_serve.get(
"error", "Tool execution failed."
)
except Exception as e:
error_msg_serve = (
f"Error processing tool {tool_name_serve}: {str(e)}"
)
current_turn_messages_serve.append(
Message(
role="tool",
tool_call_id=tool_call_request_serve["id"],
name=tool_name_serve,
content=error_msg_serve,
)
)
workflow_state["last_tool_success"] = False
tool_error_for_retry_prompt_serve = error_msg_serve
else:
workflow_state["last_tool_success"] = False
expected_tool_name_serve = {
0: "AF2",
1: "RFD",
2: "PMPNN",
3: "AF2M",
}.get(workflow_state["current_internal_step"], "a tool")
tool_error_for_retry_prompt_serve = (
f"No tool was called, but {expected_tool_name_serve} was expected."
)
workflow_state["previous_tool_error_message"] = (
tool_error_for_retry_prompt_serve
)
turn_score_details_serve = self._score_trajectory(
current_turn_messages_serve, workflow_state
)
current_turn_reward_serve = turn_score_details_serve.get(
"overall_reward", 0.0
)
workflow_state["cumulative_reward"] += current_turn_reward_serve
workflow_state["turn_messages_history"].append(
current_turn_messages_serve.copy()
)
tokenization_result_serve = tokenize_for_trainer(
self.tokenizer,
current_turn_messages_serve,
include_messages=self.config.include_messages,
)
scored_data_serve = ScoredDataGroup(
tokens=[tokenization_result_serve["tokens"]],
masks=[tokenization_result_serve["masks"]],
scores=[current_turn_reward_serve],
messages=(
[current_turn_messages_serve]
if self.config.include_messages
else None
),
overrides=[turn_score_details_serve],
group_overrides={"group_size": 1},
)
backlog_items_serve = []
if workflow_state["last_tool_success"]:
workflow_state["current_internal_step"] += 1
workflow_state["retry_count_this_internal_step"] = 0
workflow_state["previous_tool_error_message"] = None
else:
if workflow_state["current_internal_step"] <= 3:
workflow_state["retry_count_this_internal_step"] += 1
if (
workflow_state["retry_count_this_internal_step"]
> self.config.max_retries_per_internal_step
):
logger.warning(
f"Workflow {item_id}, Step {workflow_state['current_internal_step']} "
f"(Serve Mode): Max retries reached. Terminating."
)
workflow_state["workflow_complete_flag"] = True
else:
logger.warning(
f"Workflow {item_id}, Step {workflow_state['current_internal_step']} "
f"(Serve Mode): Failure at non-retryable step. Terminating."
)
workflow_state["workflow_complete_flag"] = True
if workflow_state["current_internal_step"] < 4 and not workflow_state.get(
"workflow_complete_flag"
):
should_add_to_backlog = False
if workflow_state["last_tool_success"]:
should_add_to_backlog = True
elif (
workflow_state["current_internal_step"] <= 3
and workflow_state["retry_count_this_internal_step"]
<= self.config.max_retries_per_internal_step
):
should_add_to_backlog = True
if should_add_to_backlog:
backlog_items_serve.append(item_id)
else:
workflow_state["workflow_complete_flag"] = True
logger.info(
f"Workflow for {item_id} (Serve Mode) not added to backlog and marked complete. "
f"Internal step: {workflow_state['current_internal_step']}"
)
if workflow_state.get("workflow_complete_flag"):
if item_id in self.episodes_state:
await self.add_rollouts_for_wandb(
data_for_log=self.episodes_state[item_id].copy()
)
self.completed_episode_metrics.append(
self.episodes_state[item_id].copy()
)
del self.episodes_state[item_id]
return scored_data_serve, backlog_items_serve
def _score_trajectory(
self, turn_messages: List[Message], workflow_state: Dict
) -> Dict[str, float]:
"""
Scores a single turn's trajectory based on the specified reward logic.
- Steps 0-2: Format reward (0.2 for correct & successful tool call, 0 otherwise).
- Step 3 (AF2-Multimer): Reward based on pLDDT.
"""
detailed_scores = {
"overall_reward": 0.0,
"raw_plddt": 0.0,
}
internal_step = workflow_state.get("current_internal_step")
last_tool_success = workflow_state.get("last_tool_success", False)
assistant_msg_dict = next(
(m for m in reversed(turn_messages) if m.get("role") == "assistant"), None
)
expected_tool_for_step = {
0: "predict_target_structure_alphafold2",
1: "design_binder_backbone_rfdiffusion",
2: "design_binder_sequence_proteinmpnn",
3: "evaluate_binder_complex_alphafold2_multimer",
}.get(internal_step)
called_tool_name = None
if assistant_msg_dict and assistant_msg_dict.get("tool_calls"):
tool_calls_list = assistant_msg_dict.get("tool_calls")
if (
tool_calls_list
and isinstance(tool_calls_list, list)
and len(tool_calls_list) > 0
):
function_call_dict = tool_calls_list[0].get("function")
if function_call_dict and isinstance(function_call_dict, dict):
called_tool_name = function_call_dict.get("name")
if internal_step < 3:
if last_tool_success and called_tool_name == expected_tool_for_step:
detailed_scores["overall_reward"] = 0.2
logger.info(
f"Workflow {workflow_state['item_id']}, Step {internal_step}: "
f"Correct tool '{called_tool_name}' used successfully. Reward: 0.2"
)
else:
detailed_scores["overall_reward"] = 0.0
if not last_tool_success and called_tool_name:
logger.warning(
f"Workflow {workflow_state['item_id']}, Step {internal_step}: "
f"Tool '{called_tool_name}' execution failed. Reward: 0.0"
)
elif called_tool_name != expected_tool_for_step:
logger.warning(
f"Workflow {workflow_state['item_id']}, Step {internal_step}: "
f"Incorrect tool '{called_tool_name}' used (expected '{expected_tool_for_step}'). "
f"Reward: 0.0"
)
elif not called_tool_name and expected_tool_for_step:
logger.warning(
f"Workflow {workflow_state['item_id']}, Step {internal_step}: "
f"No tool called, but expected '{expected_tool_for_step}'. Reward: 0.0"
)
elif internal_step == 3:
if (
workflow_state.get("complex_evaluated")
and last_tool_success
and called_tool_name == expected_tool_for_step
):
plddt = workflow_state.get("af2_multimer_plddt", 0.0)
detailed_scores["raw_plddt"] = plddt
if plddt > 90.0:
detailed_scores["overall_reward"] = 1.0
elif plddt > 50.0:
detailed_scores["overall_reward"] = 0.0 + (plddt - 50.0) * (
1.0 - 0.0
) / (90.0 - 50.0)
detailed_scores["overall_reward"] = max(
0.0, min(detailed_scores["overall_reward"], 1.0)
)
else:
detailed_scores["overall_reward"] = 0.0
logger.info(
f"Workflow {workflow_state['item_id']}, Step {internal_step} (AF2-Multimer): "
f"pLDDT={plddt:.2f}. Reward: {detailed_scores['overall_reward']:.2f}"
)
else:
detailed_scores["overall_reward"] = 0.0
logger.warning(
f"Workflow {workflow_state['item_id']}, Step {internal_step} (AF2-Multimer): "
f"Evaluation failed or wrong tool. Reward: 0.0. Last tool success: {last_tool_success}, "
f"Called: {called_tool_name}"
)
else:
logger.error(
f"Workflow {workflow_state['item_id']}: "
f"Invalid internal_step {internal_step} in scoring."
)
detailed_scores["overall_reward"] = -1.0
return detailed_scores
async def postprocess_histories(
self, trajectories: Optional[ScoredDataGroup]
) -> Optional[ScoredDataGroup]:
"""
Post-processes a ScoredDataGroup for a single turn.
Can be used for final adjustments or filtering if needed.
"""
return trajectories
async def evaluate(self, *args, **kwargs):
"""
Evaluate the environment's performance.
This method is called periodically by the BaseEnv.env_manager.
For BinderBenchEnv, it will aggregate metrics from completed workflows.
"""
logger.info(f"Running evaluation for {self.name}...")
if not self.completed_episode_metrics:
logger.info("No completed episodes to evaluate since last evaluation.")
self.eval_metrics = []  # Ensure eval_metrics is an empty list if no new data
if self.config.use_wandb:
await self.wandb_log({}) # Log that no eval data was present this cycle
return
plddts, cumulative_rewards, workflow_successes = [], [], []
current_eval_episodes = self.completed_episode_metrics.copy()
for ep_state in current_eval_episodes:
if ep_state.get("complex_evaluated") and ep_state.get("last_tool_success"):
plddts.append(ep_state.get("af2_multimer_plddt", 0.0))
workflow_successes.append(1.0)
else:
workflow_successes.append(0.0)
cumulative_rewards.append(ep_state.get("cumulative_reward", 0.0))
self.eval_metrics = [] # Reset class member for current evaluation results
if plddts:
self.eval_metrics.append(("eval/avg_plddt", sum(plddts) / len(plddts)))
if cumulative_rewards:
self.eval_metrics.append(
(
"eval/avg_cumulative_reward",
sum(cumulative_rewards) / len(cumulative_rewards),
)
)
if workflow_successes:
self.eval_metrics.append(
(
"eval/workflow_success_rate",
sum(workflow_successes) / len(workflow_successes),
)
)
logger.info(f"Evaluation complete. Calculated metrics: {self.eval_metrics}")
if self.config.use_wandb:
await self.wandb_log({})
self.completed_episode_metrics.clear()
async def add_rollouts_for_wandb(
self,
scored_data_group: ScoredDataGroup = None,
item_id: Item = None,
data_for_log: Dict = None,
):
"""Adds a workflow summary to the wandb rollout buffer.
This method has two modes of operation:
1. Direct logging with workflow_state (preferred for detailed logging):
- Called from within collect_trajectories with data_for_log=workflow_state.copy()
- This provides maximum detail for logging
2. BaseEnv compatibility mode:
- Called from BaseEnv.handle_send_to_api with scored_data_group and item_id
- Used automatically by the framework
- May have less detail if workflow_state was already deleted
Args:
scored_data_group: The ScoredDataGroup containing token, mask, and score data (from BaseEnv)
item_id: The item identifier, which is the key to our episodes_state (from BaseEnv)
data_for_log: Direct workflow state to log (our custom parameter for direct logging)
"""
if not self.config.use_wandb:
return
if not hasattr(self, "rollouts_for_wandb"):
self.rollouts_for_wandb = []
workflow_state = None
if data_for_log is not None and isinstance(data_for_log, dict):
workflow_state = data_for_log
if item_id is None and "item_id" in workflow_state:
item_id = workflow_state["item_id"]
elif item_id is not None and item_id in self.episodes_state:
workflow_state = self.episodes_state[item_id]
if workflow_state is None:
logger.debug(
f"No workflow_state available for WandB logging (item_id={item_id})"
)
return
target_seq = workflow_state.get("target_sequence", "N/A")
plddt = workflow_state.get("af2_multimer_plddt", 0.0)
cumulative_reward = workflow_state.get("cumulative_reward", 0.0)
last_turn_messages_str = "No messages."
try:
if (
workflow_state.get("turn_messages_history")
and len(workflow_state["turn_messages_history"]) > 0
):
last_turn_convo = workflow_state["turn_messages_history"][-1]
last_turn_messages_str = "\n---\n".join(
[
f"{msg.get('role', 'unknown')}: {str(msg.get('content', ''))[:200]}..."
for msg in last_turn_convo
]
)
except Exception as e:
logger.error(f"Error processing messages for WandB: {e}")
last_turn_messages_str = "Error processing messages"
target_preview = (
target_seq[:30] + "..."
if isinstance(target_seq, str) and len(target_seq) > 30
else target_seq
)
designed_binder_data = workflow_state.get("designed_binder_sequence", "N/A")
binder_preview = "N/A"
if isinstance(designed_binder_data, list) and designed_binder_data:
first_chain_seq = str(designed_binder_data[0])
preview_text = (
first_chain_seq[:30] + "..."
if len(first_chain_seq) > 30
else first_chain_seq
)
if len(designed_binder_data) > 1:
binder_preview = f"{len(designed_binder_data)} chains: {preview_text}"
else:
binder_preview = preview_text
elif isinstance(designed_binder_data, str) and designed_binder_data != "N/A":
binder_preview = (
designed_binder_data[:30] + "..."
if len(designed_binder_data) > 30
else designed_binder_data
)
if item_id is None:
item_id = workflow_state.get("item_id", "unknown-id")
self.rollouts_for_wandb.append(
(
str(item_id),
target_preview,
binder_preview,
f"{plddt:.2f}",
f"{cumulative_reward:.3f}",
last_turn_messages_str,
)
)
if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
self.rollouts_for_wandb.pop(0)
async def create_rollout_table(self, wandb_metrics: Dict) -> Dict:
"""Creates a wandb.Table from the buffered rollouts."""
if hasattr(self, "rollouts_for_wandb") and self.rollouts_for_wandb:
columns = [
"Item ID",
"Target (Preview)",
"Designed Binder (Preview)",
"Final pLDDT",
"Cumulative Reward",
"Last Turn Messages",
]
table = wandb.Table(columns=columns)
for rollout_tuple in self.rollouts_for_wandb:
table.add_data(*rollout_tuple)
table_key = f"env_rollouts/{self.wandb_prepend}/completed_workflows"
if self.wandb_prepend is None and hasattr(self, "name"):
table_key = f"env_rollouts/{self.name}/completed_workflows"
wandb_metrics[table_key] = table
self.rollouts_for_wandb.clear()
return wandb_metrics
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
if wandb_metrics is None:
wandb_metrics = {}
if hasattr(self, "rollouts_for_wandb") and self.rollouts_for_wandb:
wandb_metrics = await self.create_rollout_table(wandb_metrics)
await super().wandb_log(wandb_metrics)
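# Running this module directly defers to the shared BaseEnv CLI; the process-mode /
# serve-mode split handled in collect_trajectories corresponds to its data-generation
# and serving entry points (exact subcommands come from BaseEnv).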
if __name__ == "__main__":
BinderBenchEnv.cli()