mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
1045 lines · 43 KiB · Python
import json
|
|
import logging
|
|
import os
|
|
import uuid
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Tuple, TypedDict
|
|
|
|
import wandb
|
|
import yaml
|
|
from datasets import Dataset, load_dataset
|
|
from dotenv import load_dotenv
|
|
from pydantic import Field
|
|
|
|
from atroposlib.envs.base import (
|
|
APIServerConfig,
|
|
BaseEnv,
|
|
BaseEnvConfig,
|
|
Item,
|
|
ScoredDataGroup,
|
|
)
|
|
from atroposlib.type_definitions import Message
|
|
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
|
|
|
from .prompts import SYSTEM_PROMPT, construct_user_prompt
|
|
from .tool_definitions import ALL_TOOLS_LIST
|
|
from .tool_executor import ToolExecutor
|
|
|
|
# Module-level logger for this environment.
logger = logging.getLogger(__name__)

# Load API keys and other settings from a local .env file, if present.
load_dotenv()
|
|
|
|
|
|
def load_target_binder_pairs(
    dataset_name: str, target_col: str, binder_col: str, split: str = "train"
) -> Dataset:
    """
    Loads and transforms a Hugging Face dataset to contain only 'target' and 'binder' columns.

    Args:
        dataset_name (str): Hugging Face dataset identifier.
        target_col (str): Name of the column containing target protein sequences.
        binder_col (str): Name of the column containing binder sequences.
        split (str): Dataset split to load.

    Returns:
        Dataset: Hugging Face Dataset object with columns ['target', 'binder'].

    Raises:
        ValueError: If the dataset does not contain the configured columns.
    """
    ds = load_dataset(dataset_name, split=split)
    logger.info(f"Loaded dataset with columns: {ds.column_names}")

    # Bug fix: the column names were previously hardcoded to "receptor" /
    # "peptide", silently ignoring the target_col/binder_col parameters.
    try:
        ds = ds.rename_columns({target_col: "target", binder_col: "binder"})
        ds = ds.remove_columns(
            [col for col in ds.column_names if col not in {"target", "binder"}]
        )
    except ValueError as e:
        logger.error(f"Error renaming columns: {e}")
        logger.error(f"Available columns: {ds.column_names}")
        # Fallback: select first, then rename (handles datasets where
        # remove/rename ordering conflicts with extra columns).
        if target_col in ds.column_names and binder_col in ds.column_names:
            ds = ds.select_columns([target_col, binder_col])
            ds = ds.rename_columns({target_col: "target", binder_col: "binder"})
        else:
            logger.error(
                f"Could not find expected columns in dataset. Available columns: {ds.column_names}"
            )
            raise ValueError(
                f"Dataset {dataset_name} doesn't have the expected columns. Please check your dataset configuration."
            )

    return ds
|
|
|
|
|
|
class BinderRow(TypedDict):
    """One normalized dataset row as produced by load_target_binder_pairs."""

    # Target (receptor) protein sequence.
    target: str
    # Reference binder sequence; stored as the ground-truth binder in workflow state.
    binder: str
|
|
|
|
|
|
class BinderBenchConfig(BaseEnvConfig):
    """Configuration for BinderBenchEnv.

    Extends BaseEnvConfig with NVIDIA NIM API access settings, workflow
    retry/debug controls, and the Hugging Face dataset columns used for
    target/binder sequences.
    """

    # --- NVIDIA NIM API access ---
    nim_api_key: Optional[str] = Field(None, description="NVIDIA NIM API key")
    nim_api_base_url: str = Field(
        "https://health.api.nvidia.com/v1", description="NIM API base URL"
    )
    api_timeout: int = Field(1800, description="Timeout for NIM API calls")
    polling_interval: int = Field(30, description="Polling interval for NIM jobs")
    # --- Workflow output and debugging ---
    output_dir: str = Field(
        default=str(Path(__file__).parent / "outputs"),
        description="Directory to save PDBs, etc.",
    )
    debug_protein_design_calls: bool = Field(
        False,
        description="Enable debug mode for NIM protein API calls, returning mock data.",
    )
    max_retries_per_internal_step: int = Field(
        100,
        description="Max retries for a failed tool call within a workflow step (0 means no retries).",
    )
    # --- Dataset configuration ---
    dataset_name: str = Field(
        "ronig/protein_binding_sequences", description="Dataset for target sequences"
    )
    target_col: str = Field(
        "receptor", description="Target column name (actual column in the dataset)"
    )
    binder_col: str = Field(
        "peptide", description="Binder column name (actual column in the dataset)"
    )
|
|
|
|
|
|
class BinderBenchEnv(BaseEnv):
    """Multi-step protein binder design environment.

    Drives a four-step tool workflow (AlphaFold2 target prediction,
    RFdiffusion backbone design, ProteinMPNN sequence design, and
    AF2-Multimer complex evaluation) and scores each turn for training.
    """

    # Environment identifier, also used for wandb naming.
    name = "binderbench"
    # Config class instantiated by the BaseEnv framework.
    env_config_cls = BinderBenchConfig
|
|
|
|
    def __init__(
        self,
        config: BinderBenchConfig,
        server_configs: List[APIServerConfig],
        slurm=False,
        testing=False,
    ):
        """Initialize per-run state, the output directory, and the tool executor.

        Args:
            config: Environment configuration (see BinderBenchConfig).
            server_configs: LLM API server configurations forwarded to BaseEnv.
            slurm: Forwarded to BaseEnv.
            testing: Forwarded to BaseEnv.
        """
        super().__init__(config, server_configs, slurm, testing)
        # Narrow the attribute type for type checkers; BaseEnv stores config.
        self.config: BinderBenchConfig
        # When True, collect_trajectories runs the entire workflow in one call
        # (process mode) instead of one turn per call (serve mode).
        self.process_mode = False
        # Tool schemas offered to the LLM on every turn.
        self.tools = ALL_TOOLS_LIST
        self.output_dir = Path(self.config.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # item_id -> mutable workflow state dict (see _initialize_workflow_state).
        self.episodes_state = {}
        # Final workflow states of finished episodes, consumed by evaluate().
        self.completed_episode_metrics: List[Dict] = []
        # Buffered rows for the wandb rollout table.
        self.rollouts_for_wandb = []
        self.tool_executor = ToolExecutor(
            nim_api_key=self.config.nim_api_key,
            api_timeout=self.config.api_timeout,
            polling_interval=self.config.polling_interval,
            output_dir=self.output_dir,
            debug_protein_design_calls=self.config.debug_protein_design_calls,
        )
|
|
|
|
async def _execute_tool(
|
|
self, tool_name: str, args: Dict, workflow_state: Dict
|
|
) -> Dict:
|
|
"""Delegates tool execution and then updates workflow_state based on the result."""
|
|
|
|
execution_result_package = await self.tool_executor.dispatch_tool_call(
|
|
tool_name, args, workflow_state
|
|
)
|
|
tool_output = execution_result_package.get("tool_output", {})
|
|
state_updates = execution_result_package.get("state_updates", {})
|
|
if state_updates:
|
|
workflow_state.update(state_updates)
|
|
logger.debug(
|
|
f"Workflow {workflow_state['item_id']}: State updated with keys: {list(state_updates.keys())}"
|
|
)
|
|
|
|
return tool_output
|
|
|
|
@classmethod
|
|
def config_init(cls) -> Tuple[BinderBenchConfig, List[APIServerConfig]]:
|
|
default_yaml_path = (
|
|
Path(__file__).parent / "configs" / "binderbench_default.yaml"
|
|
)
|
|
yaml_config_values = {}
|
|
if default_yaml_path.exists():
|
|
with open(default_yaml_path, "r") as f:
|
|
yaml_config_values = yaml.safe_load(f) or {}
|
|
|
|
env_config = BinderBenchConfig(
|
|
use_wandb=True,
|
|
wandb_name=cls.name,
|
|
nim_api_key=os.environ.get("NVIDIA_NIM_API_KEY"),
|
|
debug_protein_design_calls=yaml_config_values.get(
|
|
"debug_protein_design_calls",
|
|
bool(os.environ.get("DEBUG_PROTEIN_DESIGN_CALLS", False)),
|
|
),
|
|
)
|
|
|
|
llm_api_key = os.environ.get("OPENAI_API_KEY")
|
|
llm_base_url = os.environ.get("OPENAI_API_BASE")
|
|
|
|
server_configs = [
|
|
APIServerConfig(
|
|
model_name=os.environ.get("DEFAULT_LLM_MODEL", "gpt-4-turbo"),
|
|
api_key=llm_api_key,
|
|
base_url=llm_base_url,
|
|
)
|
|
]
|
|
return env_config, server_configs
|
|
|
|
async def setup(self):
|
|
self.iter = 0
|
|
self.train = load_target_binder_pairs(
|
|
dataset_name=self.config.dataset_name,
|
|
target_col=self.config.target_col,
|
|
binder_col=self.config.binder_col,
|
|
)
|
|
logger.info(f"Loaded {len(self.train)} target-binder pairs for {self.name}.")
|
|
|
|
if not self.config.nim_api_key:
|
|
self.config.nim_api_key = os.environ.get("NVIDIA_NIM_API_KEY")
|
|
if not self.config.nim_api_key:
|
|
logger.warning(
|
|
"NVIDIA NIM API key not set. Protein design functions may not work properly."
|
|
)
|
|
|
|
def _initialize_workflow_state(
|
|
self, item_id: str, target_sequence: str, ground_truth_binder: Optional[str]
|
|
) -> Dict:
|
|
"""Initializes or resets the state for a new workflow."""
|
|
return {
|
|
"item_id": item_id,
|
|
"current_internal_step": 0,
|
|
"target_sequence": target_sequence,
|
|
"ground_truth_binder_sequence": ground_truth_binder,
|
|
"target_pdb_content": None,
|
|
"target_chain_details": None,
|
|
"binder_backbone_pdb_content": None,
|
|
"designed_binder_sequence": None,
|
|
"complex_pdb_content_path": None,
|
|
"af2_multimer_plddt": 0.0,
|
|
"target_structure_predicted": False,
|
|
"binder_backbone_designed": False,
|
|
"binder_sequence_designed": False,
|
|
"complex_evaluated": False,
|
|
"workflow_complete_flag": False,
|
|
"last_tool_success": True,
|
|
"cumulative_reward": 0.0,
|
|
"turn_messages_history": [],
|
|
"retry_count_this_internal_step": 0,
|
|
"previous_tool_error_message": None,
|
|
}
|
|
|
|
async def get_next_item(self) -> Item:
|
|
"""
|
|
Provides the initial information for a new protein design workflow.
|
|
Returns an Item tuple: (item_id, initial_target_sequence_info)
|
|
"""
|
|
raw_item: BinderRow = self.train[self.iter % len(self.train)]
|
|
self.iter += 1
|
|
|
|
item_id = str(uuid.uuid4())
|
|
target_sequence = raw_item["target"]
|
|
ground_truth_binder = raw_item.get("binder")
|
|
|
|
self.episodes_state[item_id] = self._initialize_workflow_state(
|
|
item_id, target_sequence, ground_truth_binder
|
|
)
|
|
|
|
return item_id
|
|
|
|
def reset_state(self, item_id: str) -> dict:
|
|
"""Retrieves the workflow state for the given item_id."""
|
|
if item_id in self.episodes_state:
|
|
return self.episodes_state[item_id]
|
|
else:
|
|
logger.error(
|
|
f"No state found for item_id {item_id}. Creating a default state."
|
|
)
|
|
return self._initialize_workflow_state(item_id, "", None)
|
|
|
|
    async def collect_trajectories(
        self, item_id: str
    ) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
        """Run one unit of the binder-design workflow for item_id and score it.

        Two modes, selected by self.process_mode:
        - Process mode: loops through all internal steps (up to 4) in a single
          call, returning one ScoredDataGroup covering the whole workflow and
          an empty backlog.
        - Serve mode: executes exactly one LLM turn + tool call, returning a
          single-turn ScoredDataGroup and (possibly) the item_id again as a
          backlog entry so the workflow continues on a later call.

        Returns:
            (scored_data, backlog_items); (None, []) when the state is missing,
            the workflow is already complete, or no turn data was collected.
        """
        workflow_state = self.episodes_state.get(item_id)
        if not workflow_state:
            logger.error(f"Workflow state for item_id {item_id} not found. Skipping.")
            return None, []

        if workflow_state.get("workflow_complete_flag"):
            logger.info(f"Workflow for {item_id} already marked complete. Skipping.")
            return None, []

        is_processing_mode = getattr(self, "process_mode", False)  # Check the flag

        if is_processing_mode:
            all_turns_data_for_jsonl = []
            MAX_INTERNAL_STEPS = 4

            # One iteration per LLM turn; retries of a failed step stay on the
            # same internal step until max_retries_per_internal_step is hit.
            while workflow_state[
                "current_internal_step"
            ] < MAX_INTERNAL_STEPS and not workflow_state.get("workflow_complete_flag"):

                # Fresh system+user prompt each turn, built from current state.
                current_turn_messages: List[Message] = []
                user_prompt_str = construct_user_prompt(workflow_state)
                current_turn_messages.append(
                    Message(role="system", content=SYSTEM_PROMPT)
                )
                current_turn_messages.append(
                    Message(role="user", content=user_prompt_str)
                )

                llm_response = await self.server.chat_completion(
                    messages=current_turn_messages,
                    tools=self.tools,
                    tool_choice="auto",
                    n=1,
                    max_tokens=self.config.max_token_length,
                    temperature=0.5,
                )
                assistant_message_obj = llm_response.choices[0].message
                assistant_content = assistant_message_obj.content or ""
                assistant_tool_calls = []
                if (
                    hasattr(assistant_message_obj, "tool_calls")
                    and assistant_message_obj.tool_calls
                ):
                    # Normalize SDK tool-call objects into plain dicts.
                    assistant_tool_calls = [
                        {
                            "id": tc.id,
                            "type": tc.type,
                            "function": {
                                "name": tc.function.name,
                                "arguments": tc.function.arguments,
                            },
                        }
                        for tc in assistant_message_obj.tool_calls
                    ]
                current_turn_messages.append(
                    Message(
                        role="assistant",
                        content=assistant_content,
                        tool_calls=(
                            assistant_tool_calls if assistant_tool_calls else None
                        ),
                    )
                )

                tool_error_for_retry_prompt = None
                if assistant_tool_calls:
                    # Only the first tool call is executed.
                    tool_call_request = assistant_tool_calls[0]
                    tool_name = tool_call_request["function"]["name"]
                    try:
                        tool_args = json.loads(
                            tool_call_request["function"]["arguments"]
                        )
                        tool_result = await self._execute_tool(
                            tool_name, tool_args, workflow_state
                        )
                        current_turn_messages.append(
                            Message(
                                role="tool",
                                tool_call_id=tool_call_request["id"],
                                name=tool_name,
                                content=json.dumps(tool_result),
                            )
                        )
                        workflow_state["last_tool_success"] = tool_result.get(
                            "success", False
                        )
                        if not workflow_state["last_tool_success"]:
                            tool_error_for_retry_prompt = tool_result.get(
                                "error", "Tool execution failed."
                            )
                    except Exception as e:
                        # Covers both malformed JSON arguments and executor errors.
                        error_msg = f"Error processing tool {tool_name}: {str(e)}"
                        current_turn_messages.append(
                            Message(
                                role="tool",
                                tool_call_id=tool_call_request["id"],
                                name=tool_name,
                                content=error_msg,
                            )
                        )
                        workflow_state["last_tool_success"] = False
                        tool_error_for_retry_prompt = error_msg
                else:
                    # No tool call at all counts as a failed step.
                    workflow_state["last_tool_success"] = False
                    expected_tool_name = {
                        0: "AF2",
                        1: "RFD",
                        2: "PMPNN",
                        3: "AF2M",
                    }.get(workflow_state["current_internal_step"], "a tool")
                    tool_error_for_retry_prompt = (
                        f"No tool was called, but {expected_tool_name} was expected."
                    )

                # Surfaced to the next turn's user prompt via construct_user_prompt.
                workflow_state["previous_tool_error_message"] = (
                    tool_error_for_retry_prompt
                )

                turn_score_details = self._score_trajectory(
                    current_turn_messages, workflow_state
                )
                current_turn_reward = turn_score_details.get("overall_reward", 0.0)
                workflow_state["cumulative_reward"] += current_turn_reward

                tokenization_result = tokenize_for_trainer(
                    self.tokenizer, current_turn_messages, include_messages=False
                )
                all_turns_data_for_jsonl.append(
                    {
                        "tokens_this_turn": tokenization_result["tokens"],
                        "masks_this_turn": tokenization_result["masks"],
                        "score_this_turn": current_turn_reward,
                        "messages_this_turn": current_turn_messages.copy(),
                        "overrides_this_turn": turn_score_details.copy(),
                    }
                )

                # Advance on success; otherwise retry or terminate.
                if workflow_state["last_tool_success"]:
                    workflow_state["current_internal_step"] += 1
                    workflow_state["retry_count_this_internal_step"] = 0
                    workflow_state["previous_tool_error_message"] = None
                else:
                    if workflow_state["current_internal_step"] <= 3:
                        workflow_state["retry_count_this_internal_step"] += 1
                        if (
                            workflow_state["retry_count_this_internal_step"]
                            > self.config.max_retries_per_internal_step
                        ):
                            logger.warning(
                                f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: "
                                f"Max retries ({self.config.max_retries_per_internal_step}) reached. "
                                f"Terminating workflow for this item."
                            )
                            workflow_state["workflow_complete_flag"] = True
                            break
                        else:
                            logger.info(
                                f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: "
                                f"Failed, attempt {workflow_state['retry_count_this_internal_step']}. "
                                f"Retrying same step."
                            )

                    else:
                        logger.warning(
                            f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: "
                            f"Failure at non-retryable step. Terminating workflow."
                        )
                        workflow_state["workflow_complete_flag"] = True
                        break

            if workflow_state["current_internal_step"] >= MAX_INTERNAL_STEPS:
                workflow_state["workflow_complete_flag"] = True
                logger.info(
                    f"Workflow {item_id}: All internal steps completed successfully."
                )

            if not all_turns_data_for_jsonl:
                logger.warning(
                    f"Workflow {item_id} in process mode: No turn data collected."
                )
                return None, []

            # Build human-readable transcripts and per-turn scores for logging.
            html_compatible_messages: List[str] = []
            html_compatible_scores: List[float] = []
            overrides_for_jsonl: List[Dict[str, Any]] = []

            for turn_idx, turn_data in enumerate(all_turns_data_for_jsonl):
                turn_str_parts = [f"--- Workflow {item_id} - Turn {turn_idx + 1} ---"]
                if turn_data.get("messages_this_turn"):
                    for msg_obj in turn_data["messages_this_turn"]:
                        content_str = str(msg_obj.get("content", "[No Content]"))
                        if msg_obj.get("tool_calls"):
                            try:
                                tool_calls_str = json.dumps(
                                    msg_obj.get("tool_calls"), indent=2
                                )
                                content_str += f"\nTool Calls:\n{tool_calls_str}"
                            except TypeError:  # Handle non-serializable content if any
                                content_str += (
                                    "\nTool Calls: [Error serializing tool_calls]"
                                )
                        turn_str_parts.append(
                            f"**{msg_obj.get('role', 'unknown').upper()}**: {content_str}"
                        )
                else:
                    turn_str_parts.append("No messages recorded for this turn.")

                html_compatible_messages.append("\n\n".join(turn_str_parts))

                turn_score = turn_data.get("overrides_this_turn", {}).get(
                    "overall_reward", 0.0
                )
                html_compatible_scores.append(turn_score)

                overrides_for_jsonl.append(turn_data.get("overrides_this_turn", {}))

            # On a fully evaluated complex, the last turn's reward supersedes
            # the cumulative sum as the workflow-level score.
            final_workflow_reward = workflow_state.get("cumulative_reward", 0.0)
            if workflow_state.get("complex_evaluated") and workflow_state.get(
                "last_tool_success"
            ):
                final_workflow_reward = (
                    all_turns_data_for_jsonl[-1]
                    .get("overrides_this_turn", {})
                    .get("overall_reward", 0.0)
                )

            all_tokens_per_turn = [
                turn_data["tokens_this_turn"]
                for turn_data in all_turns_data_for_jsonl
                if turn_data.get("tokens_this_turn")
            ]
            all_masks_per_turn = [
                turn_data["masks_this_turn"]
                for turn_data in all_turns_data_for_jsonl
                if turn_data.get("masks_this_turn")
            ]

            if len(all_tokens_per_turn) != len(html_compatible_messages):
                logger.error(
                    f"CRITICAL: Mismatch between tokenized turns ({len(all_tokens_per_turn)}) "
                    f"and HTML messages ({len(html_compatible_messages)}). JSONL will be problematic."
                )
                # Best-effort repair: pad with the last turn's tokens/masks
                # (or empty lists) so list lengths line up.
                if all_turns_data_for_jsonl and all_tokens_per_turn:
                    last_tokens = all_tokens_per_turn[-1]
                    last_masks = all_masks_per_turn[-1]
                    all_tokens_per_turn = [last_tokens] * len(html_compatible_messages)
                    all_masks_per_turn = [last_masks] * len(html_compatible_messages)
                else:
                    all_tokens_per_turn = [[] for _ in html_compatible_messages]
                    all_masks_per_turn = [[] for _ in html_compatible_messages]

            process_mode_scored_data = ScoredDataGroup(
                tokens=all_tokens_per_turn,
                masks=all_masks_per_turn,
                messages=html_compatible_messages,
                scores=html_compatible_scores,
                overrides=overrides_for_jsonl,
                group_overrides={
                    "group_size": len(html_compatible_messages),
                    "item_id": item_id,
                    "is_process_mode_full_workflow": True,
                    "final_score_for_workflow": final_workflow_reward,
                    "target_sequence": workflow_state.get("target_sequence", "N/A"),
                    "designed_binder_sequence": workflow_state.get(
                        "designed_binder_sequence", "N/A"
                    ),
                    "final_plddt": workflow_state.get("af2_multimer_plddt", 0.0),
                },
            )

            await self.add_rollouts_for_wandb(data_for_log=workflow_state.copy())

            # Episode is finished: record metrics and drop the live state.
            self.completed_episode_metrics.append(workflow_state.copy())
            if item_id in self.episodes_state:
                del self.episodes_state[item_id]
            return process_mode_scored_data, []

        else:
            # --- Serve mode: exactly one LLM turn + tool call per invocation ---
            current_turn_messages_serve: List[Message] = []
            user_prompt_str_serve = construct_user_prompt(workflow_state)
            current_turn_messages_serve.append(
                Message(role="system", content=SYSTEM_PROMPT)
            )
            current_turn_messages_serve.append(
                Message(role="user", content=user_prompt_str_serve)
            )

            llm_response_serve = await self.server.chat_completion(
                messages=current_turn_messages_serve,
                tools=self.tools,
                tool_choice="auto",
                n=1,
                max_tokens=self.config.max_token_length,
                temperature=0.5,
            )
            assistant_message_obj_serve = llm_response_serve.choices[0].message
            assistant_content_serve = assistant_message_obj_serve.content or ""
            assistant_tool_calls_serve = []
            if (
                hasattr(assistant_message_obj_serve, "tool_calls")
                and assistant_message_obj_serve.tool_calls
            ):
                assistant_tool_calls_serve = [
                    {
                        "id": tc.id,
                        "type": tc.type,
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments,
                        },
                    }
                    for tc in assistant_message_obj_serve.tool_calls
                ]
            current_turn_messages_serve.append(
                Message(
                    role="assistant",
                    content=assistant_content_serve,
                    tool_calls=(
                        assistant_tool_calls_serve
                        if assistant_tool_calls_serve
                        else None
                    ),
                )
            )

            tool_error_for_retry_prompt_serve = None
            if assistant_tool_calls_serve:
                # Only the first tool call is executed.
                tool_call_request_serve = assistant_tool_calls_serve[0]
                tool_name_serve = tool_call_request_serve["function"]["name"]
                try:
                    tool_args_json_str = tool_call_request_serve["function"][
                        "arguments"
                    ]
                    tool_args_serve = json.loads(tool_args_json_str)
                    tool_result_serve = await self._execute_tool(
                        tool_name_serve, tool_args_serve, workflow_state
                    )
                    current_turn_messages_serve.append(
                        Message(
                            role="tool",
                            tool_call_id=tool_call_request_serve["id"],
                            name=tool_name_serve,
                            content=json.dumps(tool_result_serve),
                        )
                    )
                    workflow_state["last_tool_success"] = tool_result_serve.get(
                        "success", False
                    )
                    if not workflow_state["last_tool_success"]:
                        tool_error_for_retry_prompt_serve = tool_result_serve.get(
                            "error", "Tool execution failed."
                        )
                except Exception as e:
                    error_msg_serve = (
                        f"Error processing tool {tool_name_serve}: {str(e)}"
                    )
                    current_turn_messages_serve.append(
                        Message(
                            role="tool",
                            tool_call_id=tool_call_request_serve["id"],
                            name=tool_name_serve,
                            content=error_msg_serve,
                        )
                    )
                    workflow_state["last_tool_success"] = False
                    tool_error_for_retry_prompt_serve = error_msg_serve
            else:
                workflow_state["last_tool_success"] = False
                expected_tool_name_serve = {
                    0: "AF2",
                    1: "RFD",
                    2: "PMPNN",
                    3: "AF2M",
                }.get(workflow_state["current_internal_step"], "a tool")
                tool_error_for_retry_prompt_serve = (
                    f"No tool was called, but {expected_tool_name_serve} was expected."
                )

            workflow_state["previous_tool_error_message"] = (
                tool_error_for_retry_prompt_serve
            )

            turn_score_details_serve = self._score_trajectory(
                current_turn_messages_serve, workflow_state
            )
            current_turn_reward_serve = turn_score_details_serve.get(
                "overall_reward", 0.0
            )
            workflow_state["cumulative_reward"] += current_turn_reward_serve
            workflow_state["turn_messages_history"].append(
                current_turn_messages_serve.copy()
            )

            tokenization_result_serve = tokenize_for_trainer(
                self.tokenizer,
                current_turn_messages_serve,
                include_messages=self.config.include_messages,
            )
            scored_data_serve = ScoredDataGroup(
                tokens=[tokenization_result_serve["tokens"]],
                masks=[tokenization_result_serve["masks"]],
                scores=[current_turn_reward_serve],
                messages=(
                    [current_turn_messages_serve]
                    if self.config.include_messages
                    else None
                ),
                overrides=[turn_score_details_serve],
                group_overrides={"group_size": 1},
            )

            # Advance on success; otherwise retry or terminate, mirroring
            # the process-mode step logic above.
            backlog_items_serve = []
            if workflow_state["last_tool_success"]:
                workflow_state["current_internal_step"] += 1
                workflow_state["retry_count_this_internal_step"] = 0
                workflow_state["previous_tool_error_message"] = None
            else:
                if workflow_state["current_internal_step"] <= 3:
                    workflow_state["retry_count_this_internal_step"] += 1
                    if (
                        workflow_state["retry_count_this_internal_step"]
                        > self.config.max_retries_per_internal_step
                    ):
                        logger.warning(
                            f"Workflow {item_id}, Step {workflow_state['current_internal_step']} "
                            f"(Serve Mode): Max retries reached. Terminating."
                        )
                        workflow_state["workflow_complete_flag"] = True
                else:
                    logger.warning(
                        f"Workflow {item_id}, Step {workflow_state['current_internal_step']} "
                        f"(Serve Mode): Failure at non-retryable step. Terminating."
                    )
                    workflow_state["workflow_complete_flag"] = True

            # Re-queue the item for another turn unless the workflow ended.
            if workflow_state["current_internal_step"] < 4 and not workflow_state.get(
                "workflow_complete_flag"
            ):
                should_add_to_backlog = False
                if workflow_state["last_tool_success"]:
                    should_add_to_backlog = True
                elif (
                    workflow_state["current_internal_step"] <= 3
                    and workflow_state["retry_count_this_internal_step"]
                    <= self.config.max_retries_per_internal_step
                ):
                    should_add_to_backlog = True

                if should_add_to_backlog:
                    backlog_items_serve.append(item_id)
            else:
                workflow_state["workflow_complete_flag"] = True
                logger.info(
                    f"Workflow for {item_id} (Serve Mode) not added to backlog and marked complete. "
                    f"Internal step: {workflow_state['current_internal_step']}"
                )

            if workflow_state.get("workflow_complete_flag"):
                if item_id in self.episodes_state:
                    await self.add_rollouts_for_wandb(
                        data_for_log=self.episodes_state[item_id].copy()
                    )
                    self.completed_episode_metrics.append(
                        self.episodes_state[item_id].copy()
                    )
                    del self.episodes_state[item_id]

            return scored_data_serve, backlog_items_serve
|
|
|
|
def _score_trajectory(
|
|
self, turn_messages: List[Message], workflow_state: Dict
|
|
) -> Dict[str, float]:
|
|
"""
|
|
Scores a single turn's trajectory based on the specified reward logic.
|
|
- Steps 0-2: Format reward (0.2 for correct & successful tool call, 0 otherwise).
|
|
- Step 3 (AF2-Multimer): Reward based on pLDDT.
|
|
"""
|
|
detailed_scores = {
|
|
"overall_reward": 0.0,
|
|
"raw_plddt": 0.0,
|
|
}
|
|
|
|
internal_step = workflow_state.get("current_internal_step")
|
|
last_tool_success = workflow_state.get("last_tool_success", False)
|
|
assistant_msg_dict = next(
|
|
(m for m in reversed(turn_messages) if m.get("role") == "assistant"), None
|
|
)
|
|
|
|
expected_tool_for_step = {
|
|
0: "predict_target_structure_alphafold2",
|
|
1: "design_binder_backbone_rfdiffusion",
|
|
2: "design_binder_sequence_proteinmpnn",
|
|
3: "evaluate_binder_complex_alphafold2_multimer",
|
|
}.get(internal_step)
|
|
|
|
called_tool_name = None
|
|
if assistant_msg_dict and assistant_msg_dict.get("tool_calls"):
|
|
tool_calls_list = assistant_msg_dict.get("tool_calls")
|
|
if (
|
|
tool_calls_list
|
|
and isinstance(tool_calls_list, list)
|
|
and len(tool_calls_list) > 0
|
|
):
|
|
function_call_dict = tool_calls_list[0].get("function")
|
|
if function_call_dict and isinstance(function_call_dict, dict):
|
|
called_tool_name = function_call_dict.get("name")
|
|
|
|
if internal_step < 3:
|
|
if last_tool_success and called_tool_name == expected_tool_for_step:
|
|
detailed_scores["overall_reward"] = 0.2
|
|
logger.info(
|
|
f"Workflow {workflow_state['item_id']}, Step {internal_step}: "
|
|
f"Correct tool '{called_tool_name}' used successfully. Reward: 0.2"
|
|
)
|
|
else:
|
|
detailed_scores["overall_reward"] = 0.0
|
|
if not last_tool_success and called_tool_name:
|
|
logger.warning(
|
|
f"Workflow {workflow_state['item_id']}, Step {internal_step}: "
|
|
f"Tool '{called_tool_name}' execution failed. Reward: 0.0"
|
|
)
|
|
elif called_tool_name != expected_tool_for_step:
|
|
logger.warning(
|
|
f"Workflow {workflow_state['item_id']}, Step {internal_step}: "
|
|
f"Incorrect tool '{called_tool_name}' used (expected '{expected_tool_for_step}'). "
|
|
f"Reward: 0.0"
|
|
)
|
|
elif not called_tool_name and expected_tool_for_step:
|
|
logger.warning(
|
|
f"Workflow {workflow_state['item_id']}, Step {internal_step}: "
|
|
f"No tool called, but expected '{expected_tool_for_step}'. Reward: 0.0"
|
|
)
|
|
|
|
elif internal_step == 3:
|
|
if (
|
|
workflow_state.get("complex_evaluated")
|
|
and last_tool_success
|
|
and called_tool_name == expected_tool_for_step
|
|
):
|
|
plddt = workflow_state.get("af2_multimer_plddt", 0.0)
|
|
detailed_scores["raw_plddt"] = plddt
|
|
|
|
if plddt > 90.0:
|
|
detailed_scores["overall_reward"] = 1.0
|
|
elif plddt > 50.0:
|
|
detailed_scores["overall_reward"] = 0.0 + (plddt - 50.0) * (
|
|
1.0 - 0.0
|
|
) / (90.0 - 50.0)
|
|
detailed_scores["overall_reward"] = max(
|
|
0.0, min(detailed_scores["overall_reward"], 1.0)
|
|
)
|
|
else:
|
|
detailed_scores["overall_reward"] = 0.0
|
|
|
|
logger.info(
|
|
f"Workflow {workflow_state['item_id']}, Step {internal_step} (AF2-Multimer): "
|
|
f"pLDDT={plddt:.2f}. Reward: {detailed_scores['overall_reward']:.2f}"
|
|
)
|
|
else:
|
|
detailed_scores["overall_reward"] = 0.0
|
|
logger.warning(
|
|
f"Workflow {workflow_state['item_id']}, Step {internal_step} (AF2-Multimer): "
|
|
f"Evaluation failed or wrong tool. Reward: 0.0. Last tool success: {last_tool_success}, "
|
|
f"Called: {called_tool_name}"
|
|
)
|
|
|
|
else:
|
|
logger.error(
|
|
f"Workflow {workflow_state['item_id']}: "
|
|
f"Invalid internal_step {internal_step} in scoring."
|
|
)
|
|
detailed_scores["overall_reward"] = -1.0
|
|
|
|
return detailed_scores
|
|
|
|
async def postprocess_histories(
|
|
self, trajectories: Optional[ScoredDataGroup]
|
|
) -> Optional[ScoredDataGroup]:
|
|
"""
|
|
Post-processes a ScoredDataGroup for a single turn.
|
|
Can be used for final adjustments or filtering if needed.
|
|
"""
|
|
return trajectories
|
|
|
|
    async def evaluate(self, *args, **kwargs):
        """
        Evaluate the environment's performance.
        This method is called periodically by the BaseEnv.env_manager.
        For BinderBenchEnv, it will aggregate metrics from completed workflows.

        Populates self.eval_metrics with (name, value) tuples and clears
        self.completed_episode_metrics afterwards.
        """
        logger.info(f"Running evaluation for {self.name}...")
        if not self.completed_episode_metrics:
            logger.info("No completed episodes to evaluate since last evaluation.")
            self.eval_metrics = (
                []
            )  # Ensure eval_metrics is an empty list if no new data
            if self.config.use_wandb:
                await self.wandb_log({})  # Log that no eval data was present this cycle
            return

        # A workflow counts as a success only if the final complex was
        # evaluated and its (last) tool call succeeded.
        plddts, cumulative_rewards, workflow_successes = [], [], []
        current_eval_episodes = self.completed_episode_metrics.copy()
        for ep_state in current_eval_episodes:
            if ep_state.get("complex_evaluated") and ep_state.get("last_tool_success"):
                plddts.append(ep_state.get("af2_multimer_plddt", 0.0))
                workflow_successes.append(1.0)
            else:
                workflow_successes.append(0.0)
            cumulative_rewards.append(ep_state.get("cumulative_reward", 0.0))

        self.eval_metrics = []  # Reset class member for current evaluation results
        if plddts:
            self.eval_metrics.append(("eval/avg_plddt", sum(plddts) / len(plddts)))
        if cumulative_rewards:
            self.eval_metrics.append(
                (
                    "eval/avg_cumulative_reward",
                    sum(cumulative_rewards) / len(cumulative_rewards),
                )
            )
        if workflow_successes:
            self.eval_metrics.append(
                (
                    "eval/workflow_success_rate",
                    sum(workflow_successes) / len(workflow_successes),
                )
            )

        logger.info(f"Evaluation complete. Calculated metrics: {self.eval_metrics}")

        if self.config.use_wandb:
            await self.wandb_log({})

        # Episodes are only counted once per evaluation cycle.
        self.completed_episode_metrics.clear()
|
|
|
|
async def add_rollouts_for_wandb(
|
|
self,
|
|
scored_data_group: ScoredDataGroup = None,
|
|
item_id: Item = None,
|
|
data_for_log: Dict = None,
|
|
):
|
|
"""Adds a workflow summary to the wandb rollout buffer.
|
|
|
|
This method has two modes of operation:
|
|
1. Direct logging with workflow_state (preferred for detailed logging):
|
|
- Called from within collect_trajectories with data_for_log=workflow_state.copy()
|
|
- This provides maximum detail for logging
|
|
|
|
2. BaseEnv compatibility mode:
|
|
- Called from BaseEnv.handle_send_to_api with scored_data_group and item_id
|
|
- Used automatically by the framework
|
|
- May have less detail if workflow_state was already deleted
|
|
|
|
Args:
|
|
scored_data_group: The ScoredDataGroup containing token, mask, and score data (from BaseEnv)
|
|
item_id: The item identifier, which is the key to our episodes_state (from BaseEnv)
|
|
data_for_log: Direct workflow state to log (our custom parameter for direct logging)
|
|
"""
|
|
if not self.config.use_wandb or not hasattr(self, "rollouts_for_wandb"):
|
|
if not hasattr(self, "rollouts_for_wandb"):
|
|
self.rollouts_for_wandb = []
|
|
|
|
workflow_state = None
|
|
|
|
if data_for_log is not None and isinstance(data_for_log, dict):
|
|
workflow_state = data_for_log
|
|
if item_id is None and "item_id" in workflow_state:
|
|
item_id = workflow_state["item_id"]
|
|
|
|
elif item_id is not None and item_id in self.episodes_state:
|
|
workflow_state = self.episodes_state[item_id]
|
|
|
|
if workflow_state is None:
|
|
logger.debug(
|
|
f"No workflow_state available for WandB logging (item_id={item_id})"
|
|
)
|
|
return
|
|
|
|
target_seq = workflow_state.get("target_sequence", "N/A")
|
|
|
|
plddt = workflow_state.get("af2_multimer_plddt", 0.0)
|
|
cumulative_reward = workflow_state.get("cumulative_reward", 0.0)
|
|
|
|
last_turn_messages_str = "No messages."
|
|
try:
|
|
if (
|
|
workflow_state.get("turn_messages_history")
|
|
and len(workflow_state["turn_messages_history"]) > 0
|
|
):
|
|
last_turn_convo = workflow_state["turn_messages_history"][-1]
|
|
last_turn_messages_str = "\n---\n".join(
|
|
[
|
|
f"{msg.get('role', 'unknown')}: {str(msg.get('content', ''))[:200]}..."
|
|
for msg in last_turn_convo
|
|
]
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"Error processing messages for WandB: {e}")
|
|
last_turn_messages_str = "Error processing messages"
|
|
|
|
target_preview = (
|
|
target_seq[:30] + "..."
|
|
if isinstance(target_seq, str) and len(target_seq) > 30
|
|
else target_seq
|
|
)
|
|
designed_binder_data = workflow_state.get("designed_binder_sequence", "N/A")
|
|
|
|
binder_preview = "N/A"
|
|
if isinstance(designed_binder_data, list) and designed_binder_data:
|
|
first_chain_seq = str(designed_binder_data[0])
|
|
preview_text = (
|
|
first_chain_seq[:30] + "..."
|
|
if len(first_chain_seq) > 30
|
|
else first_chain_seq
|
|
)
|
|
if len(designed_binder_data) > 1:
|
|
binder_preview = f"{len(designed_binder_data)} chains: {preview_text}"
|
|
else:
|
|
binder_preview = preview_text
|
|
elif isinstance(designed_binder_data, str) and designed_binder_data != "N/A":
|
|
binder_preview = (
|
|
designed_binder_data[:30] + "..."
|
|
if len(designed_binder_data) > 30
|
|
else designed_binder_data
|
|
)
|
|
|
|
if item_id is None:
|
|
item_id = workflow_state.get("item_id", "unknown-id")
|
|
|
|
self.rollouts_for_wandb.append(
|
|
(
|
|
str(item_id),
|
|
target_preview,
|
|
binder_preview,
|
|
f"{plddt:.2f}",
|
|
f"{cumulative_reward:.3f}",
|
|
last_turn_messages_str,
|
|
)
|
|
)
|
|
if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
|
|
self.rollouts_for_wandb.pop(0)
|
|
|
|
async def create_rollout_table(self, wandb_metrics: Dict) -> Dict:
|
|
"""Creates a wandb.Table from the buffered rollouts."""
|
|
if hasattr(self, "rollouts_for_wandb") and self.rollouts_for_wandb:
|
|
columns = [
|
|
"Item ID",
|
|
"Target (Preview)",
|
|
"Designed Binder (Preview)",
|
|
"Final pLDDT",
|
|
"Cumulative Reward",
|
|
"Last Turn Messages",
|
|
]
|
|
table = wandb.Table(columns=columns)
|
|
for rollout_tuple in self.rollouts_for_wandb:
|
|
table.add_data(*rollout_tuple)
|
|
|
|
table_key = f"env_rollouts/{self.wandb_prepend}/completed_workflows"
|
|
if self.wandb_prepend is None and hasattr(self, "name"):
|
|
table_key = f"env_rollouts/{self.name}/completed_workflows"
|
|
|
|
wandb_metrics[table_key] = table
|
|
self.rollouts_for_wandb.clear()
|
|
return wandb_metrics
|
|
|
|
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
|
if wandb_metrics is None:
|
|
wandb_metrics = {}
|
|
|
|
if hasattr(self, "rollouts_for_wandb") and self.rollouts_for_wandb:
|
|
wandb_metrics = await self.create_rollout_table(wandb_metrics)
|
|
|
|
await super().wandb_log(wandb_metrics)
|
|
|
|
|
|
if __name__ == "__main__":
    # Launch the environment via the BaseEnv command-line interface.
    BinderBenchEnv.cli()
|