import asyncio
import json
import logging
import os
import random
import re
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, TypedDict, Union

import yaml
import wandb
from dotenv import load_dotenv
from datasets import load_dataset, Dataset
from pydantic import Field

from atroposlib.envs.base import BaseEnv, BaseEnvConfig, Item, APIServerConfig, ScoredDataGroup
from atroposlib.type_definitions import Message
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

# Wrappers around the NVIDIA NIM protein-design model APIs
from environments.hack0.protein_design_env.models.alphafold2 import call_alphafold2
from environments.hack0.protein_design_env.models.rfdiffusion import call_rfdiffusion
from environments.hack0.protein_design_env.models.proteinmpnn import call_proteinmpnn
from environments.hack0.protein_design_env.models.alphafold2_multimer import call_alphafold2_multimer

logger = logging.getLogger(__name__)
load_dotenv()  # Load environment variables from a .env file

SYSTEM_PROMPT = """You are a specialized AI system for de novo protein design via a staged simulation loop. Your objective is to generate binder sequences that are structurally and functionally optimized to bind a given target protein.

You will be guided through a multi-step pipeline:

1. Structure prediction (AlphaFold2)
2. Binder backbone generation (RFdiffusion)
3. Sequence design (ProteinMPNN)
4. Structure evaluation (AlphaFold2-Multimer)
5. Feedback loop

You must always:
- Respect the required file format for each tool (e.g., FASTA, PDB).
- Structure your outputs cleanly so they can be parsed and executed programmatically.
- Be explicit in all configuration steps (e.g., contigs, hotspots).
- Minimize ambiguity and verbosity; prefer concise and functional outputs.
- Reason step-by-step when appropriate.
"""  # FIXME Improve
def load_target_binder_pairs(dataset_name: str, target_col: str, binder_col: str, split: str = "train") -> Dataset:
    """
    Loads and transforms a Hugging Face dataset to contain only 'target' and 'binder' columns.

    Args:
        dataset_name (str): Hugging Face dataset identifier.
        target_col (str): Name of the column containing target protein sequences.
        binder_col (str): Name of the column containing binder sequences.
        split (str): Dataset split to load.

    Returns:
        Dataset: Hugging Face Dataset object with columns ['target', 'binder'].
    """
    ds = load_dataset(dataset_name, split=split)

    logger.info(f"Loaded dataset with columns: {ds.column_names}")

    # Map the configured column names onto the canonical 'target'/'binder' schema.
    # The default dataset ('ronig/protein_binding_sequences') stores the target
    # protein under 'receptor' and the binder under 'peptide'.
    actual_target_col = target_col
    actual_binder_col = binder_col

    try:
        ds = ds.rename_columns({actual_target_col: "target", actual_binder_col: "binder"})
        ds = ds.remove_columns([col for col in ds.column_names if col not in {"target", "binder"}])
    except ValueError as e:
        logger.error(f"Error renaming columns: {e}")
        logger.error(f"Available columns: {ds.column_names}")
        # If renaming fails, try to select the expected columns directly.
        if actual_target_col in ds.column_names and actual_binder_col in ds.column_names:
            ds = ds.select_columns([actual_target_col, actual_binder_col])
            ds = ds.rename_columns({actual_target_col: "target", actual_binder_col: "binder"})
        else:
            # Fall back to dummy data so the environment can still be exercised
            # without the real dataset.
            logger.warning("Using dummy data since the dataset columns don't match the expected format!")
            dummy_data = {
                "target": ["MTEYKLVVVGAGGVGKSALTIQLIQNHFVDEYDPTIEDSYRKQVVIDGETCLLDILDTAGQEEY"] * 10,
                "binder": ["PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASKA"] * 10,
            }
            ds = Dataset.from_dict(dummy_data)

    return ds
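# Illustrative usage (names are the configured defaults from BinderBenchConfig):
#   ds = load_target_binder_pairs("ronig/protein_binding_sequences", "receptor", "peptide")
#   ds[0]  # -> {"target": "MTEYKLVVVGAGG...", "binder": "PIAQIHILEGRSD..."}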
def get_pdb_chain_details(pdb_content: str, preview_lines: int = 10) -> Tuple[Dict[str, Dict[str, int]], str]:
    """
    Parses PDB content to extract detailed information for each chain.

    Returns:
        A tuple containing:
        - chain_details (Dict[str, Dict[str, int]]):
            A dictionary keyed by chain ID (e.g., "A"). Each value is another
            dictionary:
            {
                "min_residue": int,  # Smallest residue number found for this chain
                "max_residue": int,  # Largest residue number found for this chain
                "length": int        # Count of C-alpha atoms (residues) in this chain
            }
        - pdb_preview (str): A string preview of the PDB content.
    """
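    # PDB ATOM records use fixed columns (1-indexed): atom name in columns 13-16,
    # chain ID in column 22, residue sequence number in columns 23-26; the slices
    # below (line[12:16], line[21:22], line[22:26]) are the 0-indexed equivalents.
    # Example record:
    #   ATOM      2  CA  ALA A   1      11.104   6.134  -6.504  1.00  0.00           C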
    chain_info_temp: Dict[str, Dict[str, Union[Set[int], int]]] = {}  # Residue numbers and CA count per chain
    atom_lines = []
    header_lines = []

    # First pass: collect all residue numbers and CA atoms per chain
    for line in pdb_content.splitlines():
        if line.startswith("ATOM"):  # Consider only ATOM records for canonical residues
            atom_lines.append(line)
            chain_id = line[21:22].strip()
            if not chain_id:
                chain_id = " "  # Default for blank chain ID; consider how RFDiffusion handles this

            atom_name = line[12:16].strip()

            try:
                residue_num = int(line[22:26].strip())

                if chain_id not in chain_info_temp:
                    chain_info_temp[chain_id] = {"residues": set(), "ca_count": 0}

                chain_info_temp[chain_id]["residues"].add(residue_num)
                if atom_name == "CA":
                    chain_info_temp[chain_id]["ca_count"] += 1
            except ValueError:
                logger.warning(f"Could not parse residue number from PDB line: {line}")
                continue
        elif line.startswith("HEADER") or line.startswith("TITLE") or line.startswith("COMPND"):
            header_lines.append(line)

    # Second pass: calculate min, max, and length from the collected data
    chain_details: Dict[str, Dict[str, int]] = {}
    for chain_id, data in chain_info_temp.items():
        if data["residues"]:  # Only process if residues were found
            min_res = min(data["residues"])
            max_res = max(data["residues"])
            # Length can be defined in a few ways:
            # 1. max_res - min_res + 1 (assumes contiguous numbering)
            # 2. Count of unique residue numbers (safer with gaps, though AF2 output is usually contiguous)
            # 3. Count of C-alpha atoms (a good proxy for actually modeled residues)
            # We use ca_count since it reflects actual modeled residues; if ca_count
            # is 0 but residues were found (e.g., only non-CA atoms), fall back to
            # the unique-residue count.
            length = data["ca_count"] if data["ca_count"] > 0 else len(data["residues"])

            chain_details[chain_id] = {
                "min_residue": min_res,
                "max_residue": max_res,
                "length": length
            }
        else:
            logger.warning(f"Chain {chain_id} had no parseable ATOM residue numbers.")

    # Construct the PDB preview: up to half from headers, the rest from ATOM lines
    preview_str_parts = header_lines[:min(len(header_lines), preview_lines // 2)]
    remaining_preview_lines = preview_lines - len(preview_str_parts)
    preview_str_parts.extend(atom_lines[:min(len(atom_lines), remaining_preview_lines)])
    pdb_preview = "\n".join(preview_str_parts)
    if len(pdb_content.splitlines()) > preview_lines:
        pdb_preview += "\n..."

    return chain_details, pdb_preview
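# Illustrative return value for a single-chain, 120-residue AF2 model:
#   chain_details -> {"A": {"min_residue": 1, "max_residue": 120, "length": 120}}
#   pdb_preview   -> the first ~10 header/ATOM lines, followed by "..."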
def get_pdb_chain_lengths_and_preview(pdb_content: str, preview_lines: int = 10) -> Tuple[Dict[str, int], str]:
    """
    Simpler variant of get_pdb_chain_details: returns only the highest residue
    number seen per chain (assuming each chain is a contiguous block of
    ATOM/HETATM records), plus a PDB preview string.
    """
    chain_lengths = {}
    current_chain_id = None
    max_residue_num = 0
    atom_lines = []
    header_lines = []

    for line in pdb_content.splitlines():
        if line.startswith("ATOM") or line.startswith("HETATM"):
            atom_lines.append(line)
            chain_id = line[21:22].strip()
            if not chain_id:  # Handle cases where the chain ID is blank
                chain_id = " "  # Default to a space if blank
            try:
                residue_num = int(line[22:26].strip())
                if current_chain_id != chain_id:
                    if current_chain_id is not None:  # Store the previous chain
                        chain_lengths[current_chain_id] = max_residue_num
                    current_chain_id = chain_id
                    max_residue_num = residue_num  # Reset for the new chain
                else:
                    max_residue_num = max(max_residue_num, residue_num)
            except ValueError:
                continue  # Skip if the residue number is not an int
        elif line.startswith("HEADER") or line.startswith("TITLE") or line.startswith("COMPND"):
            header_lines.append(line)

    if current_chain_id is not None:  # Store the last chain's length
        chain_lengths[current_chain_id] = max_residue_num

    preview_str_parts = header_lines[:min(len(header_lines), preview_lines // 2)]
    remaining_preview_lines = preview_lines - len(preview_str_parts)
    preview_str_parts.extend(atom_lines[:min(len(atom_lines), remaining_preview_lines)])

    pdb_preview = "\n".join(preview_str_parts)
    if len(pdb_content.splitlines()) > preview_lines:
        pdb_preview += "\n..."

    return chain_lengths, pdb_preview
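# The workflow prompt below is a small state machine over `current_internal_step`;
# each step expects exactly one tool call from the LLM:
#   step 0 -> predict_target_structure_alphafold2
#   step 1 -> design_binder_backbone_rfdiffusion
#   step 2 -> design_binder_sequence_proteinmpnn
#   step 3 -> evaluate_binder_complex_alphafold2_multimer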
def construct_user_prompt(state: dict) -> str:  # state is an item from self.episodes_state
    internal_step = state.get("current_internal_step", 0)
    target_sequence = state.get("target_sequence")
    user_prompt_str = ""

    if internal_step == 0:  # Step 1: Predict Target Structure (AlphaFold2)
        user_prompt_str = (
            f"The target protein sequence is: {target_sequence}. "
            "Your first task is to predict its 3D structure using the 'predict_target_structure_alphafold2' tool. "
            "You must provide the 'sequence' argument."
        )
    elif internal_step == 1:  # Step 2: Design Binder Backbone (RFDiffusion)
        target_pdb_preview = state.get("target_pdb_preview", "PDB preview not available.")  # Kept for general context

        # Format the per-chain details for the prompt
        chain_details = state.get("target_chain_details", {})
        if chain_details:
            chain_info_parts = []
            for chain_id, details in chain_details.items():
                min_r = details.get('min_residue', 'N/A')
                max_r = details.get('max_residue', 'N/A')
                length = details.get('length', 'N/A')
                chain_info_parts.append(f"Chain {chain_id} (Residues: {min_r}-{max_r}, Length: {length} amino acids)")
            chain_info_str = "\n- ".join(chain_info_parts)
            if chain_info_str:
                chain_info_str = "- " + chain_info_str  # Add a leading bullet for the first item
        else:
            chain_info_str = "Chain information not available or PDB not yet processed."

        user_prompt_str = (
            f"The 3D structure of the target protein has been predicted.\n"
            # Optional: f"Target PDB preview:\n{target_pdb_preview}\n\n"
            f"Target Protein Chain Details:\n{chain_info_str}\n\n"
            "Your task is to design a binder backbone using the 'design_binder_backbone_rfdiffusion' tool. "
            "You MUST specify 'contigs' for this tool. The 'contigs' string defines segments from the target PDB and segments for the new binder. "
            "Examples:\n"
            " - To use residues 10 through 100 of target chain A, and then diffuse a 60-residue binder: 'A10-100/0 60'\n"
            " - To use chain B from residue 5 to 50, then diffuse a 30-residue binder, then use chain B from residue 60 to 100: 'B5-50/0 30 B60-100'\n"
            "You MUST use the chain IDs and residue ranges exactly as provided in the 'Target Protein Chain Details' above. "
            "Do not invent chains or residue numbers outside these specified ranges for the target segments. "
            "For binder segments (e.g., '/0 60'), specify the desired length (e.g., 60).\n"
            "Optionally, provide 'hotspot_residues' (e.g., ['A50', 'A52']), ensuring these residues exist on the target as per the details above."
        )
    elif internal_step == 2:  # Step 3: Design Binder Sequence (ProteinMPNN)
        # Get detailed binder chain information via get_pdb_chain_details
        binder_pdb_content = state.get("binder_backbone_pdb_content")
        if binder_pdb_content:
            binder_chain_details, binder_pdb_preview = get_pdb_chain_details(binder_pdb_content)
            binder_chain_info_str = "\n- ".join([
                f"Chain {cID} (Residues: {d.get('min_residue', 'N/A')}-{d.get('max_residue', 'N/A')}, Length: {d.get('length', 'N/A')})"
                for cID, d in binder_chain_details.items()
            ])
            if binder_chain_info_str:
                binder_chain_info_str = "- " + binder_chain_info_str
        else:
            binder_pdb_preview = "Binder PDB preview not available."
            binder_chain_info_str = "Binder chain information not available."

        user_prompt_str = (
            f"A binder backbone has been generated. Binder PDB preview:\n{binder_pdb_preview}\n"
            f"Binder chain information:\n{binder_chain_info_str}.\n"
            "Now, design an optimal amino acid sequence for this binder backbone using the 'design_binder_sequence_proteinmpnn' tool. "
            "You can optionally specify 'sampling_temp' (e.g., [0.1, 0.2])."
        )
    elif internal_step == 3:  # Step 4: Evaluate Complex (AlphaFold2-Multimer)
        designed_binder_seq_data = state.get("designed_binder_sequence")  # This is a List[str]

        binder_display_str = "Not available"
        if isinstance(designed_binder_seq_data, list) and designed_binder_seq_data:
            if len(designed_binder_seq_data) == 1:
                binder_display_str = designed_binder_seq_data[0]
            else:
                binder_display_str = f"{len(designed_binder_seq_data)} chains: " + \
                    ", ".join([f"Chain {i+1} ({len(s)} aa): {s[:20]}..."
                               for i, s in enumerate(designed_binder_seq_data)])
        elif isinstance(designed_binder_seq_data, str):  # Should not happen with the current PMPNN parsing
            binder_display_str = designed_binder_seq_data

        user_prompt_str = (
            f"A binder has been designed. Designed binder sequence(s): {binder_display_str}. "
            f"The original target sequence was: {target_sequence[:60]}...\n"
            "Finally, evaluate the binding complex of the original target protein and ALL chains of this designed binder using the "
            "'evaluate_binder_complex_alphafold2_multimer' tool. "
            "You can optionally specify 'relax_prediction' (default is True)."
        )
    else:  # Workflow complete or error
        user_prompt_str = (
            "The protein design workflow is complete. No further actions are required from you for this item. "
            "If successful, the key metric was the pLDDT of the complex."
        )

    # Prepend retry context if the previous attempt at this step failed
    if state.get("retry_count_this_internal_step", 0) > 0 and internal_step < 4:
        retry_prefix = "Your previous attempt at this step was not successful. "
        if state.get("previous_tool_error_message"):
            retry_prefix += f"Details: {state['previous_tool_error_message']}. "
        retry_prefix += "Please review the requirements and PDB details carefully and try again to correctly use the expected tool.\n\n"
        user_prompt_str = retry_prefix + user_prompt_str

    return user_prompt_str
class BinderRow(TypedDict):
    target: str
    binder: str
# Configuration class for BinderBenchEnv
class BinderBenchConfig(BaseEnvConfig):
    nim_api_key: Optional[str] = Field(None, description="NVIDIA NIM API key")
    nim_api_base_url: str = Field("https://health.api.nvidia.com/v1", description="NIM API base URL")
    api_timeout: int = Field(1800, description="Timeout for NIM API calls, in seconds")
    polling_interval: int = Field(30, description="Polling interval for NIM jobs, in seconds")
    output_dir: str = Field(default=str(Path(__file__).parent / "outputs"), description="Directory to save PDBs, etc.")
    debug_protein_design_calls: bool = Field(False, description="Enable debug mode for NIM protein API calls, returning mock data.")
    max_retries_per_internal_step: int = Field(100, description="Max retries for a failed tool call within a workflow step (0 means no retries).")
    # Dataset specific
    dataset_name: str = Field("ronig/protein_binding_sequences", description="Dataset for target sequences")
    target_col: str = Field("receptor", description="Target column name (actual column in the dataset)")
    binder_col: str = Field("peptide", description="Binder column name (actual column in the dataset)")
    # Scoring weights
    metric_weights: Dict[str, float] = Field(
        default={"plddt": 0.3, "ptm": 0.3, "iptm": 0.4},
        description="Weights for combining scoring metrics for complex_quality"
    )
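
    # A minimal sketch of how these weights are intended to combine the
    # AlphaFold2-Multimer metrics into one complex_quality score (illustrative
    # only; the actual combination lives in _score_trajectory, and pLDDT is
    # assumed here to be pre-normalized to [0, 1]):
    #   complex_quality = (metric_weights["plddt"] * plddt
    #                      + metric_weights["ptm"] * ptm
    #                      + metric_weights["iptm"] * iptm)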
class BinderBenchEnv(BaseEnv):
    name = "binderbench"
    env_config_cls = BinderBenchConfig

    def __init__(self, config: BinderBenchConfig, server_configs: List[APIServerConfig], slurm=False, testing=False):
        super().__init__(config, server_configs, slurm, testing)
        self.config: BinderBenchConfig  # Type hint for convenience

        # Initialized to False; set to True when running with the process command
        self.process_mode = False

        # Tool definitions for LLM function calling
        self.tools = [
            {
                "type": "function",
                "function": {
                    "name": "predict_target_structure_alphafold2",
                    "description": "Predicts the 3D structure of the target protein sequence using AlphaFold2.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "sequence": {"type": "string", "description": "Amino acid sequence of the target protein."},
                        },
                        "required": ["sequence"]
                    }
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "design_binder_backbone_rfdiffusion",
                    "description": "Generates a novel protein binder backbone using RFDiffusion, conditioned on the target protein structure.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            # target_pdb_content is taken implicitly from workflow state
                            "contigs": {"type": "string", "description": "RFDiffusion contigs (e.g., 'A1-100/0 50-70')."},
                            "hotspot_residues": {"type": "array", "items": {"type": "string"}, "description": "Optional hotspot residues (e.g., ['A50', 'A52'])."},
                        },
                        "required": ["contigs"]
                    }
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "design_binder_sequence_proteinmpnn",
                    "description": "Designs an amino acid sequence for the generated binder backbone.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            # binder_backbone_pdb_content is taken from workflow state
                            "sampling_temp": {"type": "array", "items": {"type": "number"}, "description": "Sampling temperatures (e.g., [0.1, 0.2]). Default [0.1]."}
                        },
                        "required": []  # sampling_temp is optional
                    }
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "evaluate_binder_complex_alphafold2_multimer",
                    "description": "Predicts the complex structure of the target and designed binder, providing quality scores.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            # target_sequence and binder_sequence are taken from workflow state
                            "relax_prediction": {"type": "boolean", "description": "Whether to relax the prediction. Default True."}
                        },
                        "required": []  # relax_prediction is optional
                    }
                }
            }
        ]

        # Ensure the output directory exists
        self.output_dir = Path(self.config.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.episodes_state = {}  # Per-item workflow state, keyed by item_id
        self._debug_af2m_call_count = 0  # For debug-mode pLDDT alternation

        self.completed_episode_metrics: List[Dict] = []  # Completed workflow metrics for evaluation
        self.rollouts_for_wandb = []  # Buffer for WandB rollout data
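    # Illustrative shape of an LLM tool call consumed by the workflow below
    # (OpenAI function-calling format; 'arguments' is a JSON string):
    #   {"id": "call_0", "type": "function",
    #    "function": {"name": "design_binder_backbone_rfdiffusion",
    #                 "arguments": "{\"contigs\": \"A10-100/0 60\"}"}}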
    async def _execute_tool(self, tool_name: str, args: Dict, workflow_state: Dict) -> Dict:
        """Executes the specified NIM tool and updates the workflow_state."""
        item_id = workflow_state["item_id"]
        internal_step = workflow_state["current_internal_step"]
        logger.info(f"Workflow {item_id}, Internal Step {internal_step}: Executing tool '{tool_name}' with args: {args}")

        # Ensure the API key is available
        if not self.config.nim_api_key:
            logger.error(f"NIM API key not configured for tool {tool_name}.")
            return {"success": False, "error": "NIM API key not configured."}

        result = {"success": False, "error": "Unknown tool or execution error."}
        try:
            if tool_name == "predict_target_structure_alphafold2":
                result = await self._run_nim_alphafold2(args, workflow_state)
            elif tool_name == "design_binder_backbone_rfdiffusion":
                result = await self._run_nim_rfdiffusion(args, workflow_state)
            elif tool_name == "design_binder_sequence_proteinmpnn":
                result = await self._run_nim_proteinmpnn(args, workflow_state)
            elif tool_name == "evaluate_binder_complex_alphafold2_multimer":
                result = await self._run_nim_af2_multimer(args, workflow_state)
            else:
                result = {"success": False, "error": f"Unknown tool name: {tool_name}"}
        except Exception as e:
            logger.error(f"Workflow {item_id}, Step {internal_step}: Exception during tool '{tool_name}': {e}", exc_info=True)
            result = {"success": False, "error": str(e)}

        # The runner methods update workflow_state directly
        return result
    async def _run_nim_alphafold2(self, args: Dict, workflow_state: Dict) -> Dict:
        item_id = workflow_state["item_id"]  # For logging and unique filenames

        # ***** DEBUG MODE LOGIC FOR ALPHAFOLD2 *****
        if self.config.debug_protein_design_calls:
            logger.warning(f"DEBUG MODE: Bypassing AlphaFold2 API call for workflow {item_id}.")

            # Path to the fixed PDB file in the project root
            project_root = Path(__file__).resolve().parent.parent.parent.parent  # Up to the atropos root
            fixed_pdb_path = project_root / "binder_outputs" / "target.pdb"

            if not fixed_pdb_path.exists():
                logger.error(f"DEBUG MODE ERROR: Fixed PDB file not found at {fixed_pdb_path}. Cannot proceed with mock AF2 output.")
                # Create dummy PDB content to prevent downstream errors, but log a severe warning
                pdb_content = "HEADER    DUMMY PDB FOR DEBUG - TARGET.PDB NOT FOUND\nATOM      1  N   ALA A   1       0.000   0.000   0.000  1.00  0.00           N\nTER\nEND\n"
                workflow_state["target_pdb_content"] = pdb_content
                chain_details, pdb_preview = get_pdb_chain_details(pdb_content)
                workflow_state["target_chain_details"] = chain_details
                workflow_state["target_pdb_preview"] = pdb_preview
                workflow_state["target_structure_predicted"] = True  # Mark as "predicted" so the workflow can proceed
                return {"success": False, "error": f"Debug mode error: {fixed_pdb_path} not found.", "target_pdb_preview": pdb_preview}

            try:
                with open(fixed_pdb_path, "r") as f:
                    pdb_content = f.read()

                workflow_state["target_pdb_content"] = pdb_content
                chain_details, pdb_preview = get_pdb_chain_details(pdb_content)
                workflow_state["target_chain_details"] = chain_details
                workflow_state["target_pdb_preview"] = pdb_preview
                workflow_state["target_structure_predicted"] = True

                # Save a copy of the mock PDB to the usual output location for consistency
                debug_output_pdb_path = self.output_dir / f"target_{item_id}_s{workflow_state['current_internal_step']}_af2_DEBUG.pdb"
                with open(debug_output_pdb_path, "w") as f:
                    f.write(pdb_content)
                logger.info(f"DEBUG MODE: Used fixed PDB from {fixed_pdb_path}. Copied to {debug_output_pdb_path}. Chain details: {chain_details}")

                return {"success": True, "message": "DEBUG MODE: Used fixed PDB for AlphaFold2.", "target_pdb_preview": pdb_preview}
            except Exception as e:
                logger.error(f"DEBUG MODE ERROR: Failed to read or process {fixed_pdb_path}: {e}", exc_info=True)
                return {"success": False, "error": f"Debug mode error: Failed processing {fixed_pdb_path}."}
        # ***** END DEBUG MODE LOGIC *****

        # --- Real API call ---
        sequence = args.get("sequence")
        if not sequence:
            return {"success": False, "error": "Missing 'sequence' for AlphaFold2."}
        # Ensure the LLM uses the canonical target sequence from state
        if sequence != workflow_state["target_sequence"]:
            logger.warning(f"LLM provided sequence '{sequence[:20]}...' for AF2, but expected target '{workflow_state['target_sequence'][:20]}...'. Using expected target from workflow state.")
            sequence = workflow_state["target_sequence"]

        api_result = await call_alphafold2(
            sequence=sequence, api_key=self.config.nim_api_key,
            timeout=self.config.api_timeout, polling_interval=self.config.polling_interval
        )
        if api_result and isinstance(api_result, list) and api_result[0]:
            pdb_content = api_result[0]
            workflow_state["target_pdb_content"] = pdb_content
            chain_details, pdb_preview = get_pdb_chain_details(pdb_content)
            workflow_state["target_chain_details"] = chain_details
            workflow_state["target_pdb_preview"] = pdb_preview
            workflow_state["target_structure_predicted"] = True
            pdb_path = self.output_dir / f"target_{item_id}_s{workflow_state['current_internal_step']}_af2.pdb"
            with open(pdb_path, "w") as f:
                f.write(pdb_content)
            logger.info(f"Workflow {item_id}: AlphaFold2 PDB saved to {pdb_path}. Chain details: {chain_details}")
            return {"success": True, "message": "AlphaFold2 prediction complete.", "target_pdb_preview": pdb_preview}
        else:
            logger.error(f"Workflow {item_id}: AlphaFold2 call failed or returned unexpected data: {api_result}")
            return {"success": False, "error": "AlphaFold2 prediction failed."}
    async def _run_nim_rfdiffusion(self, args: Dict, workflow_state: Dict) -> Dict:
        target_pdb_content = workflow_state.get("target_pdb_content")
        contigs = args.get("contigs")
        if not target_pdb_content:
            return {"success": False, "error": "Target PDB not found in state for RFDiffusion."}
        if not contigs:
            return {"success": False, "error": "Missing 'contigs' for RFDiffusion."}

        hotspot_residues = args.get("hotspot_residues")  # Optional
        item_id = workflow_state["item_id"]
        api_result = await call_rfdiffusion(
            input_pdb=target_pdb_content, api_key=self.config.nim_api_key,
            contigs=contigs, hotspot_res=hotspot_residues,
            timeout=self.config.api_timeout, polling_interval=self.config.polling_interval
            # Other RFDiffusion-specific params could be added here
        )
        if api_result and "output_pdb" in api_result:
            binder_pdb = api_result["output_pdb"]
            workflow_state["binder_backbone_pdb_content"] = binder_pdb
            workflow_state["binder_backbone_designed"] = True
            pdb_path = self.output_dir / f"binder_backbone_{item_id}_s{workflow_state['current_internal_step']}_rfd.pdb"
            with open(pdb_path, "w") as f:
                f.write(binder_pdb)
            logger.info(f"Workflow {item_id}: RFDiffusion PDB saved to {pdb_path}")
            # current_internal_step is advanced by collect_trajectories, not here
            return {"success": True, "message": "RFDiffusion complete.", "binder_backbone_pdb_preview": binder_pdb[:150] + "..."}
        else:
            logger.error(f"Workflow {item_id}: RFDiffusion call failed or returned unexpected data: {api_result}")
            return {"success": False, "error": "RFDiffusion failed."}
    async def _run_nim_proteinmpnn(self, args: Dict, workflow_state: Dict) -> Dict:
        binder_pdb = workflow_state.get("binder_backbone_pdb_content")
        if not binder_pdb:
            return {"success": False, "error": "Binder backbone PDB not found for ProteinMPNN."}

        sampling_temp_list = args.get("sampling_temp", [0.1])
        item_id = workflow_state["item_id"]

        api_result = await call_proteinmpnn(
            input_pdb=binder_pdb, api_key=self.config.nim_api_key,
            sampling_temp=sampling_temp_list,
            timeout=self.config.api_timeout, polling_interval=self.config.polling_interval
        )

        if not (api_result and "mfasta" in api_result):
            logger.error(f"Workflow {item_id}: ProteinMPNN call failed or returned unexpected data: {api_result}")
            return {"success": False, "error": "ProteinMPNN call failed or no mfasta in result."}

        fasta_content = api_result["mfasta"]
        logger.info(f"CRITICAL_DEBUG: ProteinMPNN raw mfasta output for item {item_id}:\n{fasta_content}")
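
        # Illustrative mfasta entry the parser below expects (header fields vary
        # by ProteinMPNN version; only the global_score= field is relied on, and
        # '/' marks a chainbreak in the sequence line):
        #   >sample=1, T=0.1, score=1.2345, global_score=1.1042, seq_recovery=0.4500
        #   MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ/APKYLVEVRA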
        # --- FASTA parsing: find the best sequence by global_score ---
        entries: List[Tuple[float, str, str]] = []  # (global_score, header, sequence_line)
        current_header = None
        current_sequence_parts: List[str] = []

        for line_content in fasta_content.splitlines():
            line = line_content.strip()
            if not line:
                continue

            if line.startswith(">"):
                if current_header and current_sequence_parts:  # Process the previous entry
                    full_sequence_line = "".join(current_sequence_parts)
                    score_match = re.search(r"global_score=([-\d.]+)", current_header)
                    global_score = float(score_match.group(1)) if score_match else -float('inf')
                    entries.append((global_score, current_header, full_sequence_line))
                current_header = line
                current_sequence_parts = []
            else:
                current_sequence_parts.append(line)

        if current_header and current_sequence_parts:  # Process the last entry
            full_sequence_line = "".join(current_sequence_parts)
            score_match = re.search(r"global_score=([-\d.]+)", current_header)
            global_score = float(score_match.group(1)) if score_match else -float('inf')
            entries.append((global_score, current_header, full_sequence_line))

        if not entries:
            logger.error(f"Workflow {item_id}: No sequences found in ProteinMPNN mfasta output.")
            return {"success": False, "error": "No sequences parsed from ProteinMPNN mfasta."}

        # Sort by global_score (descending) and select the best
        entries.sort(key=lambda x: x[0], reverse=True)
        best_global_score, best_header, best_full_sequence_line = entries[0]

        logger.info(f"Workflow {item_id}: Best PMPNN sequence chosen (global_score={best_global_score:.4f}) from header: '{best_header}'")
        logger.info(f"Workflow {item_id}: Corresponding sequence line: '{best_full_sequence_line}'")

        # Split the selected sequence line by '/' to handle potential chainbreaks
        parsed_binder_chains: List[str] = [
            seq_part.strip() for seq_part in best_full_sequence_line.split('/') if seq_part.strip()
        ]

        if not parsed_binder_chains:
            error_msg = f"Splitting best PMPNN sequence ('{best_full_sequence_line}') by '/' yielded no valid chains."
            logger.error(f"Workflow {item_id}: {error_msg}")
            return {"success": False, "error": error_msg}

        # Validate that each parsed chain is a plausible protein sequence
        for seq_idx, seq_part in enumerate(parsed_binder_chains):
            if not (seq_part and seq_part.isalpha() and seq_part.isupper()):
                error_msg = f"Parsed binder chain {seq_idx+1} ('{seq_part[:30]}...') contains invalid characters or is empty."
                logger.error(f"Workflow {item_id}: {error_msg}")
                return {"success": False, "error": error_msg}

        workflow_state["designed_binder_sequence"] = parsed_binder_chains  # Stored as List[str]
        workflow_state["binder_sequence_designed"] = True

        fasta_path = self.output_dir / f"binder_sequence_{item_id}_s{workflow_state['current_internal_step']}_pmpnn.fasta"
        with open(fasta_path, "w") as f:
            f.write(fasta_content)  # Save the original full FASTA
        logger.info(f"Workflow {item_id}: ProteinMPNN FASTA saved to {fasta_path}. Selected binder chains: {parsed_binder_chains}")

        preview = parsed_binder_chains[0][:60] + "..." if len(parsed_binder_chains[0]) > 60 else parsed_binder_chains[0]
        if len(parsed_binder_chains) > 1:
            preview += f" (+ {len(parsed_binder_chains)-1} more chain(s))"

        return {
            "success": True,
            "message": f"ProteinMPNN complete. Selected best (global_score={best_global_score:.4f}).",
            "designed_binder_sequence_list": parsed_binder_chains,
            "designed_binder_sequence_preview": preview
        }
    async def _run_nim_af2_multimer(self, args: Dict, workflow_state: Dict) -> Dict:
        target_seq = workflow_state.get("target_sequence")
        binder_seq_data = workflow_state.get("designed_binder_sequence")

        if not target_seq:
            return {"success": False, "error": "Missing target sequence for AlphaFold2-Multimer."}

        if not binder_seq_data:
            return {"success": False, "error": "Missing binder sequence for AlphaFold2-Multimer."}

        # binder_seq_data may be a List[str] or a single string (for backward compatibility)
        binder_sequences = []
        if isinstance(binder_seq_data, list):
            binder_sequences = binder_seq_data
        elif isinstance(binder_seq_data, str):
            binder_sequences = [binder_seq_data]  # Wrap in a list
        else:
            return {"success": False, "error": f"Unexpected type for binder sequence: {type(binder_seq_data)}"}

        if not binder_sequences:
            return {"success": False, "error": "Empty binder sequence list for AlphaFold2-Multimer."}

        relax = args.get("relax_prediction", True)
        item_id = workflow_state["item_id"]

        # Log all sequences for debugging
        total_binder_length = sum(len(seq) for seq in binder_sequences)
        logger.info(f"Workflow {item_id}: Running AlphaFold2-Multimer with target (len {len(target_seq)}) and {len(binder_sequences)} binder chain(s) (total len {total_binder_length}).")

        # Debug mode: return mock data instead of calling the API
        if self.config.debug_protein_design_calls:
            # Increment the counter so results alternate between calls
            self._debug_af2m_call_count += 1
            logger.warning(f"DEBUG MODE: Using mock data for AlphaFold2-Multimer (call #{self._debug_af2m_call_count})")

            # Alternate between high- and low-quality mock scores:
            # odd-numbered calls (1, 3, 5, ...) return high quality,
            # even-numbered calls (2, 4, 6, ...) return low quality.
            if self._debug_af2m_call_count % 2 == 1:  # Odd calls
                mock_plddt = 87.5  # Good score
                success_message = "DEBUG MODE: Returning high-quality mock results"
            else:  # Even calls
                mock_plddt = 45.2  # Poor score
                success_message = "DEBUG MODE: Returning low-quality mock results"

            # Create a mock PDB file
            mock_pdb_path = self.output_dir / f"mock_complex_{item_id}_af2m.pdb"
            with open(mock_pdb_path, "w") as f:
                f.write(f"MOCK PDB FILE with pLDDT {mock_plddt}\nFor debug purposes only.\n")

            # Update the workflow state with the mock values
            workflow_state["complex_pdb_content_path"] = str(mock_pdb_path)
            workflow_state["af2_multimer_plddt"] = mock_plddt
            workflow_state["af2_multimer_ptm"] = 0.0
            workflow_state["af2_multimer_iptm"] = 0.0
            workflow_state["complex_evaluated"] = True

            logger.info(f"Workflow {item_id}: {success_message}. Mock pLDDT: {mock_plddt:.2f}")
            return {
                "success": True,
                "message": f"{success_message}. Mock pLDDT: {mock_plddt:.2f}",
                "plddt": mock_plddt,
                "ptm": 0.0,
                "iptm": 0.0,
                "complex_file_path": str(mock_pdb_path)
            }

        # Non-debug mode: proceed with the actual API call.
        # The target sequence goes first, followed by all binder sequences.
        all_sequences = [target_seq] + binder_sequences

        logger.info(f"Workflow {item_id}: Calling AlphaFold2-Multimer with {len(all_sequences)} sequences: "
                    f"1 target ({len(target_seq)} aa) + {len(binder_sequences)} binder chains.")

        api_result = await call_alphafold2_multimer(
            sequences=all_sequences,  # Flat list: [target_seq, binder_seq1, binder_seq2, ...]
            api_key=self.config.nim_api_key,
            relax_prediction=relax,
            timeout=self.config.api_timeout,
            polling_interval=self.config.polling_interval
        )

        if api_result and api_result.get("structures") and len(api_result["structures"]) > 0:
            # Assume the first structure in the list is the one we care about
            # (e.g., the rank-1 model, or the single model NIM returns).
            primary_structure_info = api_result["structures"][0]

            plddt = primary_structure_info.get("average_plddt", 0.0)
            # ipTM and pTM are not parsed by _process_nvidia_zip_output, so we do
            # not expect them here; for now the scoring focuses on pLDDT.

            workflow_state["complex_pdb_content_path"] = primary_structure_info.get("saved_pdb_path")  # Path to the extracted PDB
            workflow_state["af2_multimer_plddt"] = plddt
            workflow_state["af2_multimer_ptm"] = 0.0  # Explicitly 0 since we do not compute it
            workflow_state["af2_multimer_iptm"] = 0.0  # Explicitly 0 since we do not compute it
            workflow_state["complex_evaluated"] = True

            logger.info(f"Workflow {item_id}: AlphaFold2-Multimer complete. Average pLDDT: {plddt:.2f}")
            return {
                "success": True,
                "message": f"AlphaFold2-Multimer evaluation complete. Average pLDDT: {plddt:.2f}.",
                "plddt": plddt,
                "ptm": 0.0,  # Not used from this endpoint
                "iptm": 0.0,  # Not used from this endpoint
                "complex_file_path": str(primary_structure_info.get("saved_pdb_path"))
            }
        else:
            error_msg = "AlphaFold2-Multimer call failed or did not return expected structure data."
            if api_result and "error" in api_result:  # call_alphafold2_multimer returned an error dict
                error_msg = api_result["error"]
            logger.error(f"Workflow {item_id}: {error_msg}. API result: {api_result}")
            workflow_state["complex_evaluated"] = False  # Mark as not evaluated
            return {"success": False, "error": error_msg}
    @classmethod
    def config_init(cls) -> Tuple[BinderBenchConfig, List[APIServerConfig]]:
        # Load from binderbench_default.yaml if it exists
        default_yaml_path = Path(__file__).parent / "configs" / "binderbench_default.yaml"
        yaml_config_values = {}
        if default_yaml_path.exists():
            with open(default_yaml_path, 'r') as f:
                yaml_config_values = yaml.safe_load(f) or {}

        # Create the environment config with this priority order:
        # 1. Environment variables (e.g., NIM API key)
        # 2. YAML config values
        # 3. Default values from the Field definitions
        env_config = BinderBenchConfig(
            # --- BaseEnvConfig fields relevant to WandB ---
            use_wandb=True,  # Enable WandB by default
            wandb_name=cls.name,  # Uses BinderBenchEnv.name as the base for run names
            # num_rollouts_to_keep and num_rollouts_per_group_for_logging come from BaseEnvConfig;
            # include_messages defaults to True for process mode and can be overridden for serve

            # --- BinderBenchConfig specific fields ---
            nim_api_key=os.environ.get("NVIDIA_NIM_API_KEY"),
            debug_protein_design_calls=yaml_config_values.get(
                "debug_protein_design_calls",
                # Parse the env var explicitly: bool() on a non-empty string such
                # as "false" would otherwise be truthy.
                os.environ.get("DEBUG_PROTEIN_DESIGN_CALLS", "").lower() in ("1", "true", "yes")
            ),
            # Other config properties use the defaults from their Field definitions
        )

        # Set up the default server configs
        llm_api_key = os.environ.get("OPENAI_API_KEY")
        llm_base_url = os.environ.get("OPENAI_API_BASE")

        server_configs = [
            APIServerConfig(
                model_name=os.environ.get("DEFAULT_LLM_MODEL", "gpt-4-turbo"),
                api_key=llm_api_key,
                base_url=llm_base_url  # None if OPENAI_API_BASE is not set
            )
        ]
        return env_config, server_configs
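    # Illustrative configs/binderbench_default.yaml consumed by config_init above
    # (only debug_protein_design_calls is read from it in this method):
    #   debug_protein_design_calls: true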
    async def setup(self):
        self.iter = 0
        self.train = load_target_binder_pairs(
            dataset_name=self.config.dataset_name,
            target_col=self.config.target_col,
            binder_col=self.config.binder_col
        )
        # self.train.shuffle()  # Shuffling is useful, but makes iteration less predictable for debugging
        logger.info(f"Loaded {len(self.train)} target-binder pairs for {self.name}.")

        # Validate the API key
        if not self.config.nim_api_key:
            self.config.nim_api_key = os.environ.get("NVIDIA_NIM_API_KEY")
            if not self.config.nim_api_key:
                logger.warning("NVIDIA NIM API key not set. Protein design functions may not work properly.")

    def _initialize_workflow_state(self, item_id: str, target_sequence: str, ground_truth_binder: Optional[str]) -> Dict:
        """Initializes or resets the state for a new workflow."""
        return {
            "item_id": item_id,
            "current_internal_step": 0,
            "target_sequence": target_sequence,
            "ground_truth_binder_sequence": ground_truth_binder,  # Kept for final evaluation
            "target_pdb_content": None,
            "target_chain_details": None,  # Detailed per-chain information
            "binder_backbone_pdb_content": None,
            "designed_binder_sequence": None,
            "complex_pdb_content_path": None,  # Path to the AF2-Multimer output
            "af2_multimer_plddt": 0.0,
            "af2_multimer_ptm": 0.0,
            "af2_multimer_iptm": 0.0,
            "target_structure_predicted": False,
            "binder_backbone_designed": False,
            "binder_sequence_designed": False,
            "complex_evaluated": False,
            "workflow_complete_flag": False,  # Marks the end of the workflow
            "last_tool_success": True,  # Whether the last tool call succeeded
            "cumulative_reward": 0.0,  # For multi-step reward accumulation
            "turn_messages_history": [],  # List of List[Message], one per turn
            "retry_count_this_internal_step": 0,  # Retries for the current internal step
            "previous_tool_error_message": None,  # Informs the LLM on retry
        }
    async def get_next_item(self) -> Item:
        """
        Provides the initial information for a new protein design workflow.
        Returns the item_id; the actual data lives in self.episodes_state.
        """
        raw_item: BinderRow = self.train[self.iter % len(self.train)]
        self.iter += 1

        item_id = str(uuid.uuid4())
        target_sequence = raw_item["target"]
        ground_truth_binder = raw_item.get("binder")  # May not always be used for de novo design

        # Store the initial state for this new workflow
        self.episodes_state[item_id] = self._initialize_workflow_state(item_id, target_sequence, ground_truth_binder)

        # The "item" passed to Atropos's collect_trajectories is just the item_id;
        # the actual data is pulled from self.episodes_state[item_id]. A new
        # workflow always begins at internal step 0.
        return item_id

    # reset_state is effectively handled by _initialize_workflow_state and get_next_item
    def reset_state(self, item_id: str) -> dict:
        """Retrieves the workflow state for the given item_id."""
        if item_id in self.episodes_state:
            return self.episodes_state[item_id]
        else:
            # This should never happen in practice
            logger.error(f"No state found for item_id {item_id}. Creating a default state.")
            return self._initialize_workflow_state(item_id, "", None)  # Empty default state
    async def collect_trajectories(self, item_id: str) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
        workflow_state = self.episodes_state.get(item_id)
        if not workflow_state:
            logger.error(f"Workflow state for item_id {item_id} not found. Skipping.")
            return None, []

        if workflow_state.get("workflow_complete_flag"):
            logger.info(f"Workflow for {item_id} already marked complete. Skipping.")
            # Optionally clean up here to avoid re-processing completed items:
            # if item_id in self.episodes_state: del self.episodes_state[item_id]
            return None, []

        is_processing_mode = getattr(self, 'process_mode', False)

        if is_processing_mode:
            # --- PROCESS MODE: run the full workflow, aggregating all turns ---
            all_turns_data_for_jsonl = []  # Per-turn data for a single JSONL line
            MAX_INTERNAL_STEPS = 4  # AF2, RFD, PMPNN, AF2M
            while workflow_state["current_internal_step"] < MAX_INTERNAL_STEPS and \
                    not workflow_state.get("workflow_complete_flag"):

                # Construct the prompt (includes retry info if applicable)
                current_turn_messages: List[Message] = []
                user_prompt_str = construct_user_prompt(workflow_state)  # Uses current state, including retry info
                current_turn_messages.append(Message(role="system", content=SYSTEM_PROMPT))
                current_turn_messages.append(Message(role="user", content=user_prompt_str))

                # LLM call
                llm_response = await self.server.chat_completion(
                    messages=current_turn_messages, tools=self.tools, tool_choice="auto", n=1,
                    max_tokens=self.config.max_token_length, temperature=0.5
                )
                assistant_message_obj = llm_response.choices[0].message
                assistant_content = assistant_message_obj.content or ""
                assistant_tool_calls = []
                if hasattr(assistant_message_obj, 'tool_calls') and assistant_message_obj.tool_calls:
                    assistant_tool_calls = [
                        {"id": tc.id, "type": tc.type, "function": {"name": tc.function.name, "arguments": tc.function.arguments}}
                        for tc in assistant_message_obj.tool_calls
                    ]
                current_turn_messages.append(Message(role="assistant", content=assistant_content, tool_calls=assistant_tool_calls if assistant_tool_calls else None))

                # Tool execution
                tool_error_for_retry_prompt = None
                if assistant_tool_calls:
                    tool_call_request = assistant_tool_calls[0]
                    tool_name = tool_call_request["function"]["name"]
                    try:
                        tool_args = json.loads(tool_call_request["function"]["arguments"])
                        tool_result = await self._execute_tool(tool_name, tool_args, workflow_state)
                        current_turn_messages.append(Message(role="tool", tool_call_id=tool_call_request["id"], name=tool_name, content=json.dumps(tool_result)))
                        workflow_state["last_tool_success"] = tool_result.get("success", False)
                        if not workflow_state["last_tool_success"]:
                            tool_error_for_retry_prompt = tool_result.get("error", "Tool execution failed.")
                    except Exception as e:
                        error_msg = f"Error processing tool {tool_name}: {str(e)}"
                        current_turn_messages.append(Message(role="tool", tool_call_id=tool_call_request["id"], name=tool_name, content=error_msg))
                        workflow_state["last_tool_success"] = False
                        tool_error_for_retry_prompt = error_msg
                else:  # No tool called
                    workflow_state["last_tool_success"] = False
                    expected_tool_name = {0: "AF2", 1: "RFD", 2: "PMPNN", 3: "AF2M"}.get(workflow_state["current_internal_step"], "a tool")
                    tool_error_for_retry_prompt = f"No tool was called, but {expected_tool_name} was expected."

                workflow_state["previous_tool_error_message"] = tool_error_for_retry_prompt

                # Scoring and accumulation for the JSONL output
                turn_score_details = self._score_trajectory(current_turn_messages, workflow_state)
                current_turn_reward = turn_score_details.get("overall_reward", 0.0)
                workflow_state["cumulative_reward"] += current_turn_reward

                tokenization_result = tokenize_for_trainer(self.tokenizer, current_turn_messages, include_messages=False)
                all_turns_data_for_jsonl.append({
                    "tokens_this_turn": tokenization_result["tokens"],
                    "masks_this_turn": tokenization_result["masks"],
                    "score_this_turn": current_turn_reward,
                    "messages_this_turn": current_turn_messages.copy(),
                    "overrides_this_turn": turn_score_details.copy()
                })

                # Workflow progression / retry logic for process mode
                if workflow_state["last_tool_success"]:
                    workflow_state["current_internal_step"] += 1
                    workflow_state["retry_count_this_internal_step"] = 0  # Reset for the new step
                    workflow_state["previous_tool_error_message"] = None
                else:  # Tool call failed or was incorrect
                    if workflow_state["current_internal_step"] <= 3:  # Retry for steps 0-3
                        workflow_state["retry_count_this_internal_step"] += 1
                        if workflow_state["retry_count_this_internal_step"] > self.config.max_retries_per_internal_step:
                            logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: Max retries ({self.config.max_retries_per_internal_step}) reached. Terminating workflow for this item.")
                            workflow_state["workflow_complete_flag"] = True  # Failed to progress
                            break  # Exit the internal while loop
                        else:
                            logger.info(f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: Failed, attempt {workflow_state['retry_count_this_internal_step']}. Retrying same step.")
                            # The loop continues; construct_user_prompt will include the retry info
                    else:  # Unreachable with MAX_INTERNAL_STEPS = 4, kept for safety
                        logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: Failure at non-retryable step. Terminating workflow.")
                        workflow_state["workflow_complete_flag"] = True
                        break  # Exit the internal while loop

                if workflow_state["current_internal_step"] >= MAX_INTERNAL_STEPS:
                    workflow_state["workflow_complete_flag"] = True
                    logger.info(f"Workflow {item_id}: All internal steps completed successfully.")
                    # No break needed; the loop condition handles it
            # After the internal while loop (process mode)
            if not all_turns_data_for_jsonl:
                logger.warning(f"Workflow {item_id} in process mode: No turn data collected.")
                return None, []

            # Build the jsonl2html-compatible views of the workflow
            html_compatible_messages: List[str] = []
            html_compatible_scores: List[float] = []
            # overrides_for_jsonl stores the detailed scoring dict for each turn,
            # matching the structure of html_compatible_messages and html_compatible_scores.
            overrides_for_jsonl: List[Dict[str, Any]] = []

            for turn_idx, turn_data in enumerate(all_turns_data_for_jsonl):
                # Format this turn's messages into a single readable string
                turn_str_parts = [f"--- Workflow {item_id} - Turn {turn_idx + 1} ---"]
                if turn_data.get("messages_this_turn"):
                    for msg_obj in turn_data["messages_this_turn"]:
                        content_str = str(msg_obj.get("content", "[No Content]"))
                        if msg_obj.get("tool_calls"):
                            try:
                                tool_calls_str = json.dumps(msg_obj.get("tool_calls"), indent=2)
                                content_str += f"\nTool Calls:\n{tool_calls_str}"
                            except TypeError:  # Handle non-serializable content, if any
                                content_str += "\nTool Calls: [Error serializing tool_calls]"
                        turn_str_parts.append(f"**{msg_obj.get('role', 'unknown').upper()}**: {content_str}")
                else:
                    turn_str_parts.append("No messages recorded for this turn.")

                html_compatible_messages.append("\n\n".join(turn_str_parts))

                # The score for this specific turn
                turn_score = turn_data.get("overrides_this_turn", {}).get("overall_reward", 0.0)
                html_compatible_scores.append(turn_score)

                # The detailed scoring dictionary for this turn
                overrides_for_jsonl.append(turn_data.get("overrides_this_turn", {}))

            final_workflow_reward = workflow_state.get("cumulative_reward", 0.0)
            # If the complex was evaluated successfully, the last turn's reward is the final one
            if workflow_state.get("complex_evaluated") and workflow_state.get("last_tool_success"):
                final_workflow_reward = all_turns_data_for_jsonl[-1].get("overrides_this_turn", {}).get("overall_reward", 0.0)

            # The ScoredDataGroup handled by BaseEnv needs tokens and masks for
            # each "message" (turn) to be considered valid. Actual per-turn
            # tokens are used where available; placeholders are only a fallback,
            # since the JSONL's main purpose here is visualization via
            # messages/scores.
            all_tokens_per_turn = [turn_data["tokens_this_turn"] for turn_data in all_turns_data_for_jsonl if turn_data.get("tokens_this_turn")]
            all_masks_per_turn = [turn_data["masks_this_turn"] for turn_data in all_turns_data_for_jsonl if turn_data.get("masks_this_turn")]

            # all_tokens_per_turn and all_masks_per_turn should match the length
            # of html_compatible_messages; if some turns produced no tokens
            # (e.g., on error), pad with the last turn's tokens as a fallback.
            if len(all_tokens_per_turn) != len(html_compatible_messages):
                logger.error(f"CRITICAL: Mismatch between tokenized turns ({len(all_tokens_per_turn)}) and HTML messages ({len(html_compatible_messages)}). JSONL will be problematic.")
                # Fallback: repeat the last turn's tokens; not ideal, but keeps lengths consistent
                if all_turns_data_for_jsonl and all_tokens_per_turn:
                    last_tokens = all_tokens_per_turn[-1]
                    last_masks = all_masks_per_turn[-1]
                    all_tokens_per_turn = [last_tokens] * len(html_compatible_messages)
                    all_masks_per_turn = [last_masks] * len(html_compatible_messages)
                else:  # No token data at all
                    all_tokens_per_turn = [[] for _ in html_compatible_messages]
                    all_masks_per_turn = [[] for _ in html_compatible_messages]

            # This is the ScoredDataGroup that BaseEnv writes to the JSONL
            process_mode_scored_data = ScoredDataGroup(
                tokens=all_tokens_per_turn,  # One token list per turn
                masks=all_masks_per_turn,  # One mask list per turn

                # These two fields are what jsonl2html consumes
                messages=html_compatible_messages,  # One string per turn
                scores=html_compatible_scores,  # One float per turn

                # Detailed overrides per turn, matching the length of messages/scores
                overrides=overrides_for_jsonl,

                group_overrides={
                    "group_size": len(html_compatible_messages),  # Effective group size is the number of turns
                    "item_id": item_id,
                    "is_process_mode_full_workflow": True,
                    "final_score_for_workflow": final_workflow_reward,  # Overall workflow score
                    "target_sequence": workflow_state.get("target_sequence", "N/A"),
                    "designed_binder_sequence": workflow_state.get("designed_binder_sequence", "N/A"),
                    "final_plddt": workflow_state.get("af2_multimer_plddt", 0.0)
                }
            )

            # Log the detailed workflow state to WandB
            await self.add_rollouts_for_wandb(data_for_log=workflow_state.copy())

            self.completed_episode_metrics.append(workflow_state.copy())
            if item_id in self.episodes_state:
                del self.episodes_state[item_id]
            return process_mode_scored_data, []
else:
|
|
# --- SERVE MODE: Process one turn, use backlog for continuation ---
|
|
current_turn_messages_serve: List[Message] = []
|
|
user_prompt_str_serve = construct_user_prompt(workflow_state) # Will include retry info if state reflects it
|
|
current_turn_messages_serve.append(Message(role="system", content=SYSTEM_PROMPT))
|
|
current_turn_messages_serve.append(Message(role="user", content=user_prompt_str_serve))
|
|
|
|
llm_response_serve = await self.server.chat_completion(
|
|
messages=current_turn_messages_serve, tools=self.tools, tool_choice="auto", n=1,
|
|
max_tokens=self.config.max_token_length, temperature=0.5
|
|
)
|
|
assistant_message_obj_serve = llm_response_serve.choices[0].message
|
|
assistant_content_serve = assistant_message_obj_serve.content or ""
|
|
assistant_tool_calls_serve = []
|
|
if hasattr(assistant_message_obj_serve, 'tool_calls') and assistant_message_obj_serve.tool_calls:
|
|
assistant_tool_calls_serve = [
|
|
{"id": tc.id, "type": tc.type, "function": {"name": tc.function.name, "arguments": tc.function.arguments}}
|
|
for tc in assistant_message_obj_serve.tool_calls
|
|
]
|
|
current_turn_messages_serve.append(Message(role="assistant", content=assistant_content_serve, tool_calls=assistant_tool_calls_serve if assistant_tool_calls_serve else None))
|
|
|
|
            tool_error_for_retry_prompt_serve = None
            if assistant_tool_calls_serve:
                tool_call_request_serve = assistant_tool_calls_serve[0]
                tool_name_serve = tool_call_request_serve["function"]["name"]
                try:
                    tool_args_json_str = tool_call_request_serve["function"]["arguments"]
                    tool_args_serve = json.loads(tool_args_json_str)
                    tool_result_serve = await self._execute_tool(tool_name_serve, tool_args_serve, workflow_state)
                    current_turn_messages_serve.append(
                        Message(
                            role="tool",
                            tool_call_id=tool_call_request_serve["id"],
                            name=tool_name_serve,
                            content=json.dumps(tool_result_serve),
                        )
                    )
                    workflow_state["last_tool_success"] = tool_result_serve.get("success", False)
                    if not workflow_state["last_tool_success"]:
                        tool_error_for_retry_prompt_serve = tool_result_serve.get("error", "Tool execution failed.")
                except Exception as e:  # Catch JSONDecodeError and others
                    error_msg_serve = f"Error processing tool {tool_name_serve}: {str(e)}"
                    current_turn_messages_serve.append(
                        Message(
                            role="tool",
                            tool_call_id=tool_call_request_serve["id"],
                            name=tool_name_serve,
                            content=error_msg_serve,
                        )
                    )
                    workflow_state["last_tool_success"] = False
                    tool_error_for_retry_prompt_serve = error_msg_serve
            else:
                workflow_state["last_tool_success"] = False
                expected_tool_name_serve = {0: "AF2", 1: "RFD", 2: "PMPNN", 3: "AF2M"}.get(
                    workflow_state["current_internal_step"], "a tool"
                )
                tool_error_for_retry_prompt_serve = f"No tool was called, but {expected_tool_name_serve} was expected."

            workflow_state["previous_tool_error_message"] = tool_error_for_retry_prompt_serve

            turn_score_details_serve = self._score_trajectory(current_turn_messages_serve, workflow_state)
            current_turn_reward_serve = turn_score_details_serve.get("overall_reward", 0.0)
            workflow_state["cumulative_reward"] += current_turn_reward_serve
            workflow_state["turn_messages_history"].append(current_turn_messages_serve.copy())

            tokenization_result_serve = tokenize_for_trainer(
                self.tokenizer, current_turn_messages_serve, include_messages=self.config.include_messages
            )
            scored_data_serve = ScoredDataGroup(
                tokens=[tokenization_result_serve["tokens"]],
                masks=[tokenization_result_serve["masks"]],
                scores=[current_turn_reward_serve],
                messages=[current_turn_messages_serve] if self.config.include_messages else None,
                overrides=[turn_score_details_serve],
                group_overrides={"group_size": 1},  # Add group_overrides for serve mode too
            )

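            # scored_data_serve carries exactly one trajectory: tokens, masks, scores,
            # messages, and overrides are parallel singleton lists. Illustrative shape:
            #   scores == [0.2] for a correct step-0/1/2 tool call (see _score_trajectory).
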
            backlog_items_serve = []
            if workflow_state["last_tool_success"]:
                workflow_state["current_internal_step"] += 1
                workflow_state["retry_count_this_internal_step"] = 0
                workflow_state["previous_tool_error_message"] = None
            else:  # Tool failed or was incorrect
                if workflow_state["current_internal_step"] <= 3:  # Retry for steps 0, 1, 2, AND 3
                    workflow_state["retry_count_this_internal_step"] += 1
                    if workflow_state["retry_count_this_internal_step"] > self.config.max_retries_per_internal_step:
                        logger.warning(
                            f"Workflow {item_id}, Step {workflow_state['current_internal_step']} (Serve Mode): "
                            f"Max retries reached. Terminating."
                        )
                        workflow_state["workflow_complete_flag"] = True
                    # else: it will be added to the backlog below to retry
                else:  # Failure at non-retryable step (should never be reached with MAX_INTERNAL_STEPS = 4)
                    logger.warning(
                        f"Workflow {item_id}, Step {workflow_state['current_internal_step']} (Serve Mode): "
                        f"Failure at non-retryable step. Terminating."
                    )
                    workflow_state["workflow_complete_flag"] = True

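            # Requeue policy implemented below, summarized:
            #   tool succeeded and step < 4                   -> requeue to run the next step
            #   tool failed, step <= 3, retries remaining     -> requeue to retry the same step
            #   otherwise                                     -> mark the workflow complete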
            if workflow_state["current_internal_step"] < 4 and not workflow_state.get("workflow_complete_flag"):
                # Add to backlog if:
                # 1. Last tool was successful (to move to next step)
                # OR
                # 2. Last tool failed, current step is <= 3, and we haven't hit max retries (to retry current step)
                should_add_to_backlog = False
                if workflow_state["last_tool_success"]:
                    should_add_to_backlog = True
                elif workflow_state["current_internal_step"] <= 3 and \
                        workflow_state["retry_count_this_internal_step"] <= self.config.max_retries_per_internal_step:
                    should_add_to_backlog = True

                if should_add_to_backlog:
                    backlog_items_serve.append(item_id)
                else:  # Condition for adding to backlog not met
                    workflow_state["workflow_complete_flag"] = True  # Mark as complete (due to failure beyond retries)
                    logger.info(
                        f"Workflow for {item_id} (Serve Mode) not added to backlog and marked complete. "
                        f"Internal step: {workflow_state['current_internal_step']}"
                    )

            if workflow_state.get("workflow_complete_flag"):  # If flag was set either by reaching step 4 or by retry logic
                # For completed workflows in serve mode, use direct logging with workflow_state
                # before it gets deleted.
                if item_id in self.episodes_state:
                    # Use direct workflow_state logging for maximum detail
                    await self.add_rollouts_for_wandb(data_for_log=self.episodes_state[item_id].copy())
                    self.completed_episode_metrics.append(self.episodes_state[item_id].copy())
                    del self.episodes_state[item_id]
                # Note: we don't need to call add_rollouts_for_wandb with scored_data_serve here;
                # BaseEnv.handle_send_to_api will call it automatically with the scored_data_serve
                # that we return.

            return scored_data_serve, backlog_items_serve

    def _score_trajectory(self, turn_messages: List[Message], workflow_state: Dict) -> Dict[str, float]:
        """
        Scores a single turn's trajectory based on the specified reward logic.

        - Steps 0-2: format reward (0.2 for a correct & successful tool call, 0 otherwise).
        - Step 3 (AF2-Multimer): reward based on pLDDT.
        """
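        # Worked example of the step-3 pLDDT-to-reward mapping implemented below:
        #   pLDDT = 95 -> reward 1.0 (above the 90.0 ceiling)
        #   pLDDT = 70 -> reward (70 - 50) / 40 = 0.5 (linear ramp between 50.0 and 90.0)
        #   pLDDT = 45 -> reward 0.0 (at or below the 50.0 floor)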
        detailed_scores = {
            "overall_reward": 0.0,
            "raw_plddt": 0.0,
        }

        internal_step = workflow_state.get("current_internal_step")
        last_tool_success = workflow_state.get("last_tool_success", False)

        # Find the most recent assistant message in this turn (messages are plain dicts)
        assistant_msg_dict = next((m for m in reversed(turn_messages) if m.get("role") == "assistant"), None)

        expected_tool_for_step = {
            0: "predict_target_structure_alphafold2",
            1: "design_binder_backbone_rfdiffusion",
            2: "design_binder_sequence_proteinmpnn",
            3: "evaluate_binder_complex_alphafold2_multimer",
        }.get(internal_step)

        called_tool_name = None
        if assistant_msg_dict and assistant_msg_dict.get("tool_calls"):
            tool_calls_list = assistant_msg_dict.get("tool_calls")
            if isinstance(tool_calls_list, list) and tool_calls_list:
                # tool_calls_list[0] is a dict, per the Message structure built above
                function_call_dict = tool_calls_list[0].get("function")
                if function_call_dict and isinstance(function_call_dict, dict):
                    called_tool_name = function_call_dict.get("name")

        # --- Scoring for Steps 0, 1, 2 (internal steps before AF2-Multimer) ---
        # The explicit None check routes a missing step into the error branch below
        # instead of raising a TypeError on the comparison.
        if internal_step is not None and internal_step < 3:
            if last_tool_success and called_tool_name == expected_tool_for_step:
                detailed_scores["overall_reward"] = 0.2
                logger.info(
                    f"Workflow {workflow_state['item_id']}, Step {internal_step}: "
                    f"Correct tool '{called_tool_name}' used successfully. Reward: 0.2"
                )
            else:
                detailed_scores["overall_reward"] = 0.0
                if not last_tool_success and called_tool_name:
                    logger.warning(
                        f"Workflow {workflow_state['item_id']}, Step {internal_step}: "
                        f"Tool '{called_tool_name}' execution failed. Reward: 0.0"
                    )
                elif called_tool_name != expected_tool_for_step:
                    logger.warning(
                        f"Workflow {workflow_state['item_id']}, Step {internal_step}: "
                        f"Incorrect tool '{called_tool_name}' used (expected '{expected_tool_for_step}'). Reward: 0.0"
                    )
                elif not called_tool_name and expected_tool_for_step:
                    logger.warning(
                        f"Workflow {workflow_state['item_id']}, Step {internal_step}: "
                        f"No tool called, but expected '{expected_tool_for_step}'. Reward: 0.0"
                    )

        # --- Scoring for Step 3 (AF2-Multimer evaluation) ---
        elif internal_step == 3:
            if workflow_state.get("complex_evaluated") and last_tool_success and called_tool_name == expected_tool_for_step:
                plddt = workflow_state.get("af2_multimer_plddt", 0.0)
                detailed_scores["raw_plddt"] = plddt

                if plddt > 90.0:
                    detailed_scores["overall_reward"] = 1.0
                elif plddt > 50.0:
                    # Linear ramp from 0.0 at pLDDT 50 to 1.0 at pLDDT 90, clamped to [0, 1]
                    detailed_scores["overall_reward"] = max(0.0, min((plddt - 50.0) / 40.0, 1.0))
                else:
                    detailed_scores["overall_reward"] = 0.0

                logger.info(
                    f"Workflow {workflow_state['item_id']}, Step {internal_step} (AF2-Multimer): "
                    f"pLDDT={plddt:.2f}. Reward: {detailed_scores['overall_reward']:.2f}"
                )
            else:
                detailed_scores["overall_reward"] = 0.0
                logger.warning(
                    f"Workflow {workflow_state['item_id']}, Step {internal_step} (AF2-Multimer): "
                    f"Evaluation failed or wrong tool. Reward: 0.0. "
                    f"Last tool success: {last_tool_success}, Called: {called_tool_name}"
                )

        else:
            logger.error(f"Workflow {workflow_state['item_id']}: Invalid internal_step {internal_step} in scoring.")
            detailed_scores["overall_reward"] = -1.0

        return detailed_scores

    async def postprocess_histories(
        self, trajectories: Optional[ScoredDataGroup]
    ) -> Optional[ScoredDataGroup]:
        """
        Post-processes a ScoredDataGroup for a single turn.
        Can be used for final adjustments or filtering if needed.
        """
        # Just pass through trajectories without modification
        return trajectories

    async def evaluate(self, *args, **kwargs):
        """
        Evaluate the environment's performance.

        This method is called periodically by BaseEnv.env_manager.
        For BinderBenchEnv, it aggregates metrics from completed workflows.
        """
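        # After this method runs, self.eval_metrics is a list of (name, value) pairs
        # consumed by wandb_log, e.g. (illustrative values only):
        #   [("eval/avg_plddt", 72.4), ("eval/avg_cumulative_reward", 0.9),
        #    ("eval/workflow_success_rate", 0.5)]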
        logger.info(f"Running evaluation for {self.name}...")
        if not self.completed_episode_metrics:
            logger.info("No completed episodes to evaluate since last evaluation.")
            self.eval_metrics = []  # Ensure eval_metrics is an empty list if no new data
            if self.config.use_wandb:
                await self.wandb_log({})  # Log that no eval data was present this cycle
            return

        # --- Metrics calculation ---
        # These metrics are based on the episodes completed *since the last evaluation*,
        # or since the start if this is the first evaluation.
        plddts, ptms, iptms, cumulative_rewards, workflow_successes = [], [], [], [], []

        # Use a copy of the buffer for this evaluation cycle; the main buffer is
        # cleared at the end of this method.
        current_eval_episodes = self.completed_episode_metrics.copy()

        for ep_state in current_eval_episodes:
            if ep_state.get("complex_evaluated") and ep_state.get("last_tool_success"):
                plddts.append(ep_state.get("af2_multimer_plddt", 0.0))
                # ptms.append(ep_state.get("af2_multimer_ptm", 0.0))    # currently always 0.0, so not aggregated
                # iptms.append(ep_state.get("af2_multimer_iptm", 0.0))  # currently always 0.0, so not aggregated
                workflow_successes.append(1.0)
            else:
                workflow_successes.append(0.0)
            cumulative_rewards.append(ep_state.get("cumulative_reward", 0.0))

        self.eval_metrics = []  # Reset class member for the current evaluation results
        if plddts:
            self.eval_metrics.append(("eval/avg_plddt", sum(plddts) / len(plddts)))
        # if ptms:  # Not currently populated with real values
        #     self.eval_metrics.append(("eval/avg_ptm", sum(ptms) / len(ptms)))
        # if iptms:  # Not currently populated with real values
        #     self.eval_metrics.append(("eval/avg_iptm", sum(iptms) / len(iptms)))
        if cumulative_rewards:
            self.eval_metrics.append(("eval/avg_cumulative_reward", sum(cumulative_rewards) / len(cumulative_rewards)))
        if workflow_successes:
            self.eval_metrics.append(("eval/workflow_success_rate", sum(workflow_successes) / len(workflow_successes)))

        logger.info(f"Evaluation complete. Calculated metrics: {self.eval_metrics}")

        # Log to WandB immediately after evaluation if enabled;
        # self.wandb_log will pick up self.eval_metrics.
        if self.config.use_wandb:
            await self.wandb_log({})

        # Clear self.completed_episode_metrics only *after* this eval cycle has processed
        # them, so old data is not re-evaluated. Since the buffer is populated by
        # collect_trajectories, clearing here is appropriate for periodic evaluation.
        self.completed_episode_metrics.clear()

    async def add_rollouts_for_wandb(
        self,
        scored_data_group: ScoredDataGroup = None,  # From BaseEnv
        item_id: Item = None,  # From BaseEnv
        data_for_log: Dict = None,  # Our custom param for direct workflow_state logging
    ):
        """Adds a workflow summary to the wandb rollout buffer.

        This method has two modes of operation:

        1. Direct logging with workflow_state (preferred for detailed logging):
           - Called from within collect_trajectories with data_for_log=workflow_state.copy()
           - This provides maximum detail for logging

        2. BaseEnv compatibility mode:
           - Called from BaseEnv.handle_send_to_api with scored_data_group and item_id
           - Used automatically by the framework
           - May have less detail if workflow_state was already deleted

        Args:
            scored_data_group: The ScoredDataGroup containing token, mask, and score data (from BaseEnv)
            item_id: The item identifier, which is the key to our episodes_state (from BaseEnv)
            data_for_log: Direct workflow state to log (our custom parameter for direct logging)
        """
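        # Illustrative call patterns (variable names here are placeholders):
        #   await self.add_rollouts_for_wandb(data_for_log=workflow_state.copy())      # mode 1
        #   await self.add_rollouts_for_wandb(scored_data_group=sdg, item_id=item_id)  # mode 2
        # Mode 2 is issued automatically by BaseEnv.handle_send_to_api.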
        if not self.config.use_wandb:
            return

        # Ensure the rollout buffer exists before appending to it
        if not hasattr(self, "rollouts_for_wandb"):
            self.rollouts_for_wandb = []

        # Determine the workflow state to use
        workflow_state = None

        # Case 1: Direct workflow state provided (most detailed)
        if data_for_log is not None and isinstance(data_for_log, dict):
            workflow_state = data_for_log
            # Extract item_id from data_for_log if needed
            if item_id is None and "item_id" in workflow_state:
                item_id = workflow_state["item_id"]

        # Case 2: Try to get workflow_state from episodes_state (if not already deleted)
        elif item_id is not None and item_id in self.episodes_state:
            workflow_state = self.episodes_state[item_id]

        # Case 3: No usable state - early return with a debug log (not a warning).
        # This happens in BaseEnv.handle_send_to_api after the workflow is already completed.
        if workflow_state is None:
            # Expected in BaseEnv's call after workflow_state is deleted, so use debug level
            logger.debug(f"No workflow_state available for WandB logging (item_id={item_id})")
            return

        # Customize what you want to see in the WandB table for a completed workflow.
        # Handle cases where values might be None.
        target_seq = workflow_state.get("target_sequence", "N/A")

        # Handle designed_binder, which might be None
        designed_binder = workflow_state.get("designed_binder_sequence", "N/A")
        if designed_binder is None:
            designed_binder = "N/A"

        plddt = workflow_state.get("af2_multimer_plddt", 0.0)
        iptm = workflow_state.get("af2_multimer_iptm", 0.0)  # Even if 0, log it
        cumulative_reward = workflow_state.get("cumulative_reward", 0.0)

        # Storing the full turn_messages_history can make the table huge,
        # so only the last turn's messages are included here.
        last_turn_messages_str = "No messages."
        try:
            if workflow_state.get("turn_messages_history") and len(workflow_state["turn_messages_history"]) > 0:
                last_turn_convo = workflow_state["turn_messages_history"][-1]
                last_turn_messages_str = "\n---\n".join(
                    [f"{msg.get('role', 'unknown')}: {str(msg.get('content', ''))[:200]}..." for msg in last_turn_convo]
                )
        except Exception as e:
            logger.error(f"Error processing messages for WandB: {e}")
            last_turn_messages_str = "Error processing messages"

        # Safely truncate strings for the preview columns
        target_preview = target_seq[:30] + "..." if isinstance(target_seq, str) and len(target_seq) > 30 else target_seq

        if designed_binder == "N/A" or designed_binder is None:
            binder_preview = "N/A"
        else:
            binder_preview = designed_binder[:30] + "..." if len(str(designed_binder)) > 30 else designed_binder

        # Use item_id from workflow_state if still None
        if item_id is None:
            item_id = workflow_state.get("item_id", "unknown-id")

        # Add to the rollouts buffer
        self.rollouts_for_wandb.append(
            (  # This tuple structure will be used by create_rollout_table
                str(item_id),  # Ensure item_id is a string
                target_preview,
                binder_preview,
                f"{plddt:.2f}",
                f"{iptm:.2f}",
                f"{cumulative_reward:.3f}",
                last_turn_messages_str,  # Or a link to the full JSONL entry, or more details
            )
        )
        if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
            self.rollouts_for_wandb.pop(0)

    async def create_rollout_table(self, wandb_metrics: Dict) -> Dict:
        """Creates a wandb.Table from the buffered rollouts."""
        if hasattr(self, "rollouts_for_wandb") and self.rollouts_for_wandb:
            # Define columns based on what add_rollouts_for_wandb appends
            columns = ["Item ID", "Target (Preview)", "Designed Binder (Preview)",
                       "Final pLDDT", "Final ipTM", "Cumulative Reward", "Last Turn Messages"]
            table = wandb.Table(columns=columns)
            for rollout_tuple in self.rollouts_for_wandb:
                table.add_data(*rollout_tuple)  # Unpack the tuple into one table row

            # Use a unique key for the table, prepended by wandb_prepend
            table_key = f"env_rollouts/{self.wandb_prepend}/completed_workflows"
            if self.wandb_prepend is None and hasattr(self, "name"):  # Fallback if wandb_prepend is not set yet
                table_key = f"env_rollouts/{self.name}/completed_workflows"

            wandb_metrics[table_key] = table
            self.rollouts_for_wandb.clear()
        return wandb_metrics

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        if wandb_metrics is None:
            wandb_metrics = {}

        # Create and add the rollout table to wandb_metrics
        if hasattr(self, "rollouts_for_wandb") and self.rollouts_for_wandb:
            wandb_metrics = await self.create_rollout_table(wandb_metrics)

        # Add any training-time aggregated metrics here (not from self.completed_episode_metrics,
        # as that is handled by evaluate() for eval-specific logging). For example, with a
        # buffer of per-turn scores collected during training rollouts:
        # if self.per_turn_score_buffer:
        #     wandb_metrics[f"train/{self.wandb_prepend}/avg_turn_reward"] = sum(self.per_turn_score_buffer) / len(self.per_turn_score_buffer)
        #     self.per_turn_score_buffer.clear()

        # self.eval_metrics (populated by evaluate()) will be picked up by super().wandb_log()
        await super().wandb_log(wandb_metrics)

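# Example invocation: a sketch only, assuming this file is saved as protein_design_env.py
# and that the installed atroposlib exposes the usual BaseEnv.cli subcommands (the exact
# command names and flags depend on the atroposlib version):
#   python protein_design_env.py process   # run full workflows offline and score them
#   python protein_design_env.py serve     # serve per-turn scored data to a trainer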
if __name__ == "__main__":
    BinderBenchEnv.cli()