rfdiffusion fix

This commit is contained in:
based-tachikoma 2025-05-19 19:42:48 -07:00
parent 4d9bec44c6
commit de9dfff221
8 changed files with 1253 additions and 104 deletions

View file

@ -3,9 +3,10 @@ import json
import logging
import os
import random
import re
import uuid
from pathlib import Path
from typing import Dict, List, Any, Tuple, Optional, Union, TypedDict
from typing import Dict, List, Any, Tuple, Optional, Union, TypedDict, Set
import yaml
import wandb # Add import for wandb
@ -89,6 +90,85 @@ def load_target_binder_pairs(dataset_name: str, target_col: str, binder_col: str
return ds
def get_pdb_chain_details(pdb_content: str, preview_lines: int = 10) -> Tuple[Dict[str, Dict[str, int]], str]:
    """
    Parse PDB text and summarise every chain found in its ATOM records.

    Only lines starting with "ATOM" are inspected (HETATM records are ignored).
    Fixed PDB columns are used: atom name [12:16], chain ID [21:22],
    residue number [22:26] and insertion code [26:27].

    Args:
        pdb_content: Full text of a PDB file.
        preview_lines: Maximum number of lines to include in the preview
            (negative values are treated as 0).

    Returns:
        A tuple containing:
        - chain_details (Dict[str, Dict[str, int]]):
            A dictionary where keys are chain IDs (e.g., "A"; a blank chain ID
            is reported as " "). Each value is another dictionary:
            {
                "min_residue": int,  # Smallest residue number found for this chain
                "max_residue": int,  # Largest residue number found for this chain
                "length": int        # Number of distinct residues owning a C-alpha atom;
                                     # falls back to the number of distinct residue
                                     # numbers when the chain has no C-alpha atoms
            }
        - pdb_preview (str): Up to ``preview_lines // 2`` HEADER/TITLE/COMPND lines
          followed by ATOM lines, with a trailing "..." line when the input has
          more lines than ``preview_lines``.
    """
    preview_lines = max(preview_lines, 0)

    # chain ID -> every residue number seen in that chain's ATOM records
    residue_numbers: Dict[str, Set[int]] = {}
    # chain ID -> (residue number, insertion code) of residues that own a CA atom.
    # A set is used so alternate-location CA records of one residue count once,
    # while residues that differ only by insertion code stay distinct.
    ca_residues: Dict[str, Set[Tuple[int, str]]] = {}
    atom_lines: List[str] = []
    header_lines: List[str] = []
    all_lines = pdb_content.splitlines()

    for line in all_lines:
        if line.startswith("ATOM"):
            atom_lines.append(line)
            # A blank chain ID is kept as a single space so it remains a usable key.
            chain_id = line[21:22].strip() or " "
            try:
                residue_num = int(line[22:26].strip())
            except ValueError:
                logger.warning(f"Could not parse residue number from PDB line: {line}")
                continue
            residue_numbers.setdefault(chain_id, set()).add(residue_num)
            if line[12:16].strip() == "CA":
                insertion_code = line[26:27].strip()
                ca_residues.setdefault(chain_id, set()).add((residue_num, insertion_code))
        elif line.startswith(("HEADER", "TITLE", "COMPND")):
            header_lines.append(line)

    # A chain only gets an entry once a residue number parsed, so every set is non-empty.
    chain_details: Dict[str, Dict[str, int]] = {}
    for chain_id, numbers in residue_numbers.items():
        ca_count = len(ca_residues.get(chain_id, ()))
        chain_details[chain_id] = {
            "min_residue": min(numbers),
            "max_residue": max(numbers),
            "length": ca_count if ca_count > 0 else len(numbers),
        }

    # Preview: at most half the budget goes to header lines, the rest to ATOM lines.
    preview_parts = header_lines[: preview_lines // 2]
    remaining = preview_lines - len(preview_parts)
    preview_parts.extend(atom_lines[:remaining])
    pdb_preview = "\n".join(preview_parts)
    if len(all_lines) > preview_lines:
        pdb_preview += "\n..."

    return chain_details, pdb_preview
def get_pdb_chain_lengths_and_preview(pdb_content: str, preview_lines: int = 10) -> Tuple[Dict[str, int], str]:
chain_lengths = {}
current_chain_id = None
@ -130,11 +210,10 @@ def get_pdb_chain_lengths_and_preview(pdb_content: str, preview_lines: int = 10)
return chain_lengths, pdb_preview
def construct_user_prompt(state: dict) -> str: # state is an item from self.episodes_state
internal_step = state.get("current_internal_step", 0) # Use internal step from our state
internal_step = state.get("current_internal_step", 0)
target_sequence = state.get("target_sequence")
user_prompt_str = ""
# Base prompt construction (your existing logic)
if internal_step == 0: # Step 1: Predict Target Structure (AlphaFold2)
user_prompt_str = (
f"The target protein sequence is: {target_sequence}. "
@ -142,52 +221,85 @@ def construct_user_prompt(state: dict) -> str: # state is an item from self.epis
"You must provide the 'sequence' argument."
)
elif internal_step == 1: # Step 2: Design Binder Backbone (RFDiffusion)
# Use the stored preview and chain info
target_pdb_preview = state.get("target_pdb_preview", "PDB preview not available.")
chain_info = state.get("target_chain_info", {})
chain_info_str = ", ".join([f"Chain {cID} (length {length})" for cID, length in chain_info.items()])
if not chain_info_str: chain_info_str = "Chain information not available."
target_pdb_preview = state.get("target_pdb_preview", "PDB preview not available.") # Can keep preview for general context
# --- NEW CHAIN INFO FORMATTING ---
chain_details = state.get("target_chain_details", {}) # Get the new detailed info
if chain_details:
chain_info_parts = []
for chain_id, details in chain_details.items():
min_r = details.get('min_residue', 'N/A')
max_r = details.get('max_residue', 'N/A')
l = details.get('length', 'N/A')
chain_info_parts.append(f"Chain {chain_id} (Residues: {min_r}-{max_r}, Length: {l} amino acids)")
chain_info_str = "\n- ".join(chain_info_parts)
if chain_info_str:
chain_info_str = "- " + chain_info_str # Add leading bullet for the first item
else:
chain_info_str = "Chain information not available or PDB not yet processed."
# --- END NEW CHAIN INFO FORMATTING ---
user_prompt_str = (
f"The 3D structure of the target protein has been predicted. Target PDB preview:\n{target_pdb_preview}\n"
f"Target chain information: {chain_info_str}.\n"
"Now, design a binder backbone using the 'design_binder_backbone_rfdiffusion' tool. "
"You need to specify 'contigs'. Contigs define segments from the target PDB (e.g., 'A1-100' means residues 1-100 of target chain A) "
"and segments for the new binder (e.g., '/0 50-70' means generate a new chain of length 50 to 70 residues). "
"A full example for a 60-residue binder for Chain A of the target (if Chain A has 100 residues): 'A1-100/0 60'. "
"Ensure any target residue numbers in contigs are within the valid range for the respective chain. "
"Optionally, provide 'hotspot_residues' (e.g., ['A50', 'A52']), ensuring they exist on the target."
f"The 3D structure of the target protein has been predicted.\n"
# Optional: f"Target PDB preview:\n{target_pdb_preview}\n\n"
f"Target Protein Chain Details:\n{chain_info_str}\n\n" # Use the detailed chain info
"Your task is to design a binder backbone using the 'design_binder_backbone_rfdiffusion' tool. "
"You MUST specify 'contigs' for this tool. The 'contigs' string defines segments from the target PDB and segments for the new binder. "
"Examples:\n"
" - To use residues 10 through 100 of target chain A, and then diffuse a 60-residue binder: 'A10-100/0 60'\n"
" - To use chain B from residue 5 to 50, then diffuse a 30-residue binder, then use chain B from residue 60 to 100: 'B5-50/0 30 B60-100'\n"
"You MUST use the chain IDs and residue ranges exactly as provided in the 'Target Protein Chain Details' above. "
"Do not invent chains or residue numbers outside these specified ranges for the target segments. "
"For binder segments (e.g., '/0 60'), specify the desired length (e.g., 60).\n"
"Optionally, provide 'hotspot_residues' (e.g., ['A50', 'A52']), ensuring these residues exist on the target as per the details above."
)
elif internal_step == 2: # Step 3: Design Binder Sequence (ProteinMPNN)
binder_pdb_preview = state.get("binder_pdb_preview", "Binder PDB preview not available.")
binder_chain_info = state.get("binder_chain_info", {}) # Info about the binder backbone itself
binder_info_str = ", ".join([f"Chain {cID} (length {length})" for cID, length in binder_chain_info.items()])
if not binder_info_str: binder_info_str = "Binder chain information not available."
# Get detailed binder chain information using the get_pdb_chain_details function
binder_pdb_content = state.get("binder_backbone_pdb_content")
if binder_pdb_content:
binder_chain_details, binder_pdb_preview = get_pdb_chain_details(binder_pdb_content)
binder_chain_info_str = "\n- ".join([f"Chain {cID} (Residues: {d.get('min_residue','N/A')}-{d.get('max_residue','N/A')}, Length: {d.get('length','N/A')})" for cID, d in binder_chain_details.items()])
if binder_chain_info_str: binder_chain_info_str = "- " + binder_chain_info_str
else:
binder_pdb_preview = "Binder PDB preview not available."
binder_chain_info_str = "Binder chain information not available."
user_prompt_str = (
f"A binder backbone has been generated. Binder PDB preview:\n{binder_pdb_preview}\n"
f"Binder chain information: {binder_info_str}.\n"
f"Binder chain information:\n{binder_chain_info_str}.\n"
"Now, design an optimal amino acid sequence for this binder backbone using the 'design_binder_sequence_proteinmpnn' tool. "
"You can optionally specify 'sampling_temp' (e.g., [0.1, 0.2])."
)
elif internal_step == 3: # Step 4: Evaluate Complex (AlphaFold2-Multimer)
designed_binder_seq = state.get("designed_binder_sequence", "Not yet available")
designed_binder_seq_data = state.get("designed_binder_sequence") # This is List[str]
binder_display_str = "Not available"
if isinstance(designed_binder_seq_data, list) and designed_binder_seq_data:
if len(designed_binder_seq_data) == 1:
binder_display_str = designed_binder_seq_data[0]
else:
binder_display_str = f"{len(designed_binder_seq_data)} chains: " + \
", ".join([f"Chain {i+1} ({len(s)} aa): {s[:20]}..."
for i, s in enumerate(designed_binder_seq_data)])
elif isinstance(designed_binder_seq_data, str): # Should not happen with new PMPNN parsing
binder_display_str = designed_binder_seq_data
user_prompt_str = (
f"A binder sequence has been designed: {designed_binder_seq}. "
f"The original target sequence was: {target_sequence}.\n" # Remind LLM of original target
"Finally, evaluate the binding complex of the original target protein and this designed binder using the "
f"A binder has been designed. Designed binder sequence(s): {binder_display_str}. "
f"The original target sequence was: {target_sequence[:60]}...\n"
"Finally, evaluate the binding complex of the original target protein and ALL chains of this designed binder using the "
"'evaluate_binder_complex_alphafold2_multimer' tool. "
"You can optionally specify 'relax_prediction' (default is True)."
)
else: # Workflow complete or error
user_prompt_str = "The protein design workflow is complete. No further actions required by you for this item. If successful, the key metric was the pLDDT of the complex."
# ***** ADD RETRY PREFIX IF APPLICABLE *****
if state.get("retry_count_this_internal_step", 0) > 0 and internal_step < 4: # For all steps
# Retry logic should remain the same:
if state.get("retry_count_this_internal_step", 0) > 0 and internal_step < 4:
retry_prefix = "Your previous attempt at this step was not successful. "
if state.get("previous_tool_error_message"):
retry_prefix += f"Details: {state['previous_tool_error_message']}. "
retry_prefix += "Please review the requirements and try again to correctly use the expected tool.\n\n"
retry_prefix += "Please review the requirements and PDB details carefully and try again to correctly use the expected tool.\n\n"
user_prompt_str = retry_prefix + user_prompt_str
return user_prompt_str
@ -204,7 +316,7 @@ class BinderBenchConfig(BaseEnvConfig):
nim_api_base_url: str = Field("https://health.api.nvidia.com/v1", description="NIM API base URL")
api_timeout: int = Field(1800, description="Timeout for NIM API calls") # Increased default
polling_interval: int = Field(30, description="Polling interval for NIM jobs") # Increased default
output_dir: str = Field("binder_outputs", description="Directory to save PDBs, etc.")
output_dir: str = Field(default=str(Path(__file__).parent / "outputs"), description="Directory to save PDBs, etc.")
debug_protein_design_calls: bool = Field(False, description="Enable debug mode for NIM protein API calls, returning mock data.")
max_retries_per_internal_step: int = Field(100, description="Max retries for a failed tool call within a workflow step (0 means no retries).") # Default to 1 retry (2 attempts total)
# Dataset specific
@ -349,8 +461,8 @@ class BinderBenchEnv(BaseEnv):
# Create a dummy PDB content if file not found to prevent downstream errors, but log severe warning
pdb_content = "HEADER DUMMY PDB FOR DEBUG - TARGET.PDB NOT FOUND\nATOM 1 N ALA A 1 0.000 0.000 0.000 1.00 0.00 N\nTER\nEND\n"
workflow_state["target_pdb_content"] = pdb_content
chain_lengths, pdb_preview = get_pdb_chain_lengths_and_preview(pdb_content) # Use your helper
workflow_state["target_chain_info"] = chain_lengths
chain_details, pdb_preview = get_pdb_chain_details(pdb_content) # Use the new function
workflow_state["target_chain_details"] = chain_details # Store detailed info
workflow_state["target_pdb_preview"] = pdb_preview
workflow_state["target_structure_predicted"] = True # Mark as "predicted" for workflow to proceed
return {"success": False, "error": f"Debug mode error: {fixed_pdb_path} not found.", "target_pdb_preview": pdb_preview}
@ -360,15 +472,15 @@ class BinderBenchEnv(BaseEnv):
pdb_content = f.read()
workflow_state["target_pdb_content"] = pdb_content
chain_lengths, pdb_preview = get_pdb_chain_lengths_and_preview(pdb_content) # Use your helper
workflow_state["target_chain_info"] = chain_lengths
chain_details, pdb_preview = get_pdb_chain_details(pdb_content) # Use the new function
workflow_state["target_chain_details"] = chain_details # Store detailed info
workflow_state["target_pdb_preview"] = pdb_preview
workflow_state["target_structure_predicted"] = True
# Optionally, save a copy of this mock PDB to the usual output location for consistency
debug_output_pdb_path = self.output_dir / f"target_{item_id}_s{workflow_state['current_internal_step']}_af2_DEBUG.pdb"
with open(debug_output_pdb_path, "w") as f: f.write(pdb_content)
logger.info(f"DEBUG MODE: Used fixed PDB from {fixed_pdb_path}. Copied to {debug_output_pdb_path}. Chain info: {chain_lengths}")
logger.info(f"DEBUG MODE: Used fixed PDB from {fixed_pdb_path}. Copied to {debug_output_pdb_path}. Chain details: {chain_details}")
return {"success": True, "message": "DEBUG MODE: Used fixed PDB for AlphaFold2.", "target_pdb_preview": pdb_preview}
except Exception as e:
@ -393,13 +505,13 @@ class BinderBenchEnv(BaseEnv):
if api_result and isinstance(api_result, list) and api_result[0]:
pdb_content = api_result[0]
workflow_state["target_pdb_content"] = pdb_content
chain_lengths, pdb_preview = get_pdb_chain_lengths_and_preview(pdb_content)
workflow_state["target_chain_info"] = chain_lengths
chain_details, pdb_preview = get_pdb_chain_details(pdb_content) # Use the new function
workflow_state["target_chain_details"] = chain_details # Store detailed info
workflow_state["target_pdb_preview"] = pdb_preview
workflow_state["target_structure_predicted"] = True
pdb_path = self.output_dir / f"target_{item_id}_s{workflow_state['current_internal_step']}_af2.pdb"
with open(pdb_path, "w") as f: f.write(pdb_content)
logger.info(f"Workflow {item_id}: AlphaFold2 PDB saved to {pdb_path}. Chain info: {chain_lengths}")
logger.info(f"Workflow {item_id}: AlphaFold2 PDB saved to {pdb_path}. Chain details: {chain_details}")
return {"success": True, "message": "AlphaFold2 prediction complete.", "target_pdb_preview": pdb_preview}
else:
logger.error(f"Workflow {item_id}: AlphaFold2 call failed or returned unexpected data: {api_result}")
@ -434,47 +546,125 @@ class BinderBenchEnv(BaseEnv):
async def _run_nim_proteinmpnn(self, args: Dict, workflow_state: Dict) -> Dict:
    """
    Design a binder sequence with ProteinMPNN and store it in the workflow state.

    Reads the binder backbone from ``workflow_state["binder_backbone_pdb_content"]``,
    calls ``call_proteinmpnn``, parses the returned multi-FASTA ("mfasta"), selects
    the record with the best ``global_score`` and splits it on '/' into chains.

    Args:
        args: Tool arguments; ``sampling_temp`` is optional (defaults to ``[0.1]``).
        workflow_state: Mutable per-item state. On success this method sets
            ``designed_binder_sequence`` (List[str], one entry per chain) and
            ``binder_sequence_designed`` (True).

    Returns:
        ``{"success": False, "error": ...}`` on any failure, otherwise a dict with
        ``success``, ``message``, ``designed_binder_sequence_list`` and
        ``designed_binder_sequence_preview``.
    """
    binder_pdb = workflow_state.get("binder_backbone_pdb_content")
    if not binder_pdb:
        return {"success": False, "error": "Binder backbone PDB not found for ProteinMPNN."}

    sampling_temp_list = args.get("sampling_temp", [0.1])
    item_id = workflow_state["item_id"]

    api_result = await call_proteinmpnn(
        input_pdb=binder_pdb,
        api_key=self.config.nim_api_key,
        sampling_temp=sampling_temp_list,
        timeout=self.config.api_timeout,
        polling_interval=self.config.polling_interval,
    )

    if not (api_result and "mfasta" in api_result):
        logger.error(f"Workflow {item_id}: ProteinMPNN call failed or returned unexpected data: {api_result}")
        return {"success": False, "error": "ProteinMPNN call failed or no mfasta in result."}

    fasta_content = api_result["mfasta"]
    logger.info(f"CRITICAL_DEBUG: ProteinMPNN raw mfasta output for item {item_id}:\n{fasta_content}")

    # --- Group the FASTA into (header, sequence) records ---
    # Sequence lines may be wrapped, so they are joined per header. Headers with
    # no sequence lines, and sequence lines before the first header, are dropped.
    records: List[Tuple[str, str]] = []
    header: Optional[str] = None
    sequence_parts: List[str] = []
    for raw_line in fasta_content.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if header and sequence_parts:
                records.append((header, "".join(sequence_parts)))
            header, sequence_parts = line, []
        else:
            sequence_parts.append(line)
    if header and sequence_parts:  # Flush the final record
        records.append((header, "".join(sequence_parts)))

    if not records:
        logger.error(f"Workflow {item_id}: No sequences found in ProteinMPNN mfasta output.")
        return {"success": False, "error": "No sequences parsed from ProteinMPNN mfasta."}

    # --- Score each record ---
    # The pattern only matches well-formed numbers, so float() cannot raise on
    # fragments such as "-" or "1.2.3" that a looser character class would capture.
    score_pattern = re.compile(r"global_score=(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)")
    entries: List[Tuple[float, str, str]] = []  # (global_score, header, sequence_line)
    for record_header, record_sequence in records:
        score_match = score_pattern.search(record_header)
        # Records without a score rank last (lower is better, see below).
        global_score = float(score_match.group(1)) if score_match else float("inf")
        entries.append((global_score, record_header, record_sequence))

    # ProteinMPNN's global_score is an average negative log-probability, so the
    # LOWEST value is the most confident design. min() keeps the first record on ties.
    # NOTE(review): ProteinMPNN output typically starts with the input backbone's own
    # sequence; it competes on score like any other record -- confirm whether it
    # should be excluded from selection.
    best_global_score, best_header, best_full_sequence_line = min(entries, key=lambda entry: entry[0])

    logger.info(f"Workflow {item_id}: Best PMPNN sequence chosen (global_score={best_global_score:.4f}) from header: '{best_header}'")
    logger.info(f"Workflow {item_id}: Corresponding sequence line: '{best_full_sequence_line}'")

    # Split the selected sequence line by '/' to handle potential chainbreaks
    parsed_binder_chains: List[str] = [
        seq_part.strip() for seq_part in best_full_sequence_line.split('/') if seq_part.strip()
    ]

    if not parsed_binder_chains:
        error_msg = f"Splitting best PMPNN sequence ('{best_full_sequence_line}') by '/' yielded no valid chains."
        logger.error(f"Workflow {item_id}: {error_msg}")
        return {"success": False, "error": error_msg}

    # Validate each parsed chain (upper-case letters only)
    for seq_idx, seq_part in enumerate(parsed_binder_chains):
        if not (seq_part and seq_part.isalpha() and seq_part.isupper()):
            error_msg = f"Parsed binder chain {seq_idx+1} ('{seq_part[:30]}...') contains invalid characters or is empty."
            logger.error(f"Workflow {item_id}: {error_msg}")
            return {"success": False, "error": error_msg}

    workflow_state["designed_binder_sequence"] = parsed_binder_chains  # Stored as List[str]
    workflow_state["binder_sequence_designed"] = True

    # Save the original, unfiltered FASTA next to the other workflow artifacts.
    fasta_path = self.output_dir / f"binder_sequence_{item_id}_s{workflow_state['current_internal_step']}_pmpnn.fasta"
    with open(fasta_path, "w") as f:
        f.write(fasta_content)
    logger.info(f"Workflow {item_id}: ProteinMPNN FASTA saved to {fasta_path}. Selected binder chains: {parsed_binder_chains}")

    first_chain = parsed_binder_chains[0]
    preview = first_chain[:60] + "..." if len(first_chain) > 60 else first_chain
    if len(parsed_binder_chains) > 1:
        preview += f" (+ {len(parsed_binder_chains)-1} more chain(s))"

    return {
        "success": True,
        "message": f"ProteinMPNN complete. Selected best (global_score={best_global_score:.4f}).",
        "designed_binder_sequence_list": parsed_binder_chains,
        "designed_binder_sequence_preview": preview
    }
async def _run_nim_af2_multimer(self, args: Dict, workflow_state: Dict) -> Dict:
target_seq = workflow_state.get("target_sequence")
binder_seq = workflow_state.get("designed_binder_sequence")
if not target_seq or not binder_seq:
return {"success": False, "error": "Missing target or binder sequence for AlphaFold2-Multimer."}
binder_seq_data = workflow_state.get("designed_binder_sequence")
if not target_seq:
return {"success": False, "error": "Missing target sequence for AlphaFold2-Multimer."}
if not binder_seq_data:
return {"success": False, "error": "Missing binder sequence for AlphaFold2-Multimer."}
# Handle binder_seq_data which could now be either a List[str] or a single string (for backward compatibility)
binder_sequences = []
if isinstance(binder_seq_data, list):
binder_sequences = binder_seq_data
elif isinstance(binder_seq_data, str):
binder_sequences = [binder_seq_data] # Wrap in list
else:
return {"success": False, "error": f"Unexpected type for binder sequence: {type(binder_seq_data)}"}
if not binder_sequences:
return {"success": False, "error": "Empty binder sequence list for AlphaFold2-Multimer."}
relax = args.get("relax_prediction", True)
item_id = workflow_state["item_id"]
logger.info(f"Workflow {item_id}: Running AlphaFold2-Multimer with target (len {len(target_seq)}) and binder (len {len(binder_seq)}).")
# Log all sequences for debugging
total_binder_length = sum(len(seq) for seq in binder_sequences)
logger.info(f"Workflow {item_id}: Running AlphaFold2-Multimer with target (len {len(target_seq)}) and {len(binder_sequences)} binder chain(s) (total len {total_binder_length}).")
# Check if in debug mode
if self.config.debug_protein_design_calls:
@ -515,12 +705,18 @@ class BinderBenchEnv(BaseEnv):
}
# Non-debug mode: proceed with actual API call
# Create a list with target sequence as first element, followed by all binder sequences
all_sequences = [target_seq] + binder_sequences
logger.info(f"Workflow {item_id}: Calling AlphaFold2-Multimer with {len(all_sequences)} sequences: "
f"1 target ({len(target_seq)} aa) + {len(binder_sequences)} binder chains.")
api_result = await call_alphafold2_multimer(
sequences=[target_seq, binder_seq],
sequences=all_sequences, # Pass all sequences as a flat list: [target_seq, binder_seq1, binder_seq2, ...]
api_key=self.config.nim_api_key,
relax_prediction=relax,
timeout=self.config.api_timeout, # Pass configured timeout
polling_interval=self.config.polling_interval # Pass configured polling interval
timeout=self.config.api_timeout,
polling_interval=self.config.polling_interval
)
if api_result and api_result.get("structures") and len(api_result["structures"]) > 0:
@ -624,6 +820,7 @@ class BinderBenchEnv(BaseEnv):
"target_sequence": target_sequence,
"ground_truth_binder_sequence": ground_truth_binder, # Store for final evaluation
"target_pdb_content": None,
"target_chain_details": None, # Store detailed chain information
"binder_backbone_pdb_content": None,
"designed_binder_sequence": None,
"complex_pdb_content_path": None, # Path to AF2-Multimer output
@ -758,7 +955,7 @@ class BinderBenchEnv(BaseEnv):
workflow_state["retry_count_this_internal_step"] = 0 # Reset for new step
workflow_state["previous_tool_error_message"] = None
else: # Tool call failed or was incorrect
if workflow_state["current_internal_step"] < 3: # Only retry for steps 0-2
if workflow_state["current_internal_step"] <= 3: # Retry for steps 0, 1, 2, AND 3
workflow_state["retry_count_this_internal_step"] += 1
if workflow_state["retry_count_this_internal_step"] > self.config.max_retries_per_internal_step:
logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: Max retries ({self.config.max_retries_per_internal_step}) reached. Terminating workflow for this item.")
@ -767,8 +964,8 @@ class BinderBenchEnv(BaseEnv):
else:
logger.info(f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: Failed, attempt {workflow_state['retry_count_this_internal_step']}. Retrying same step.")
# Loop continues, construct_user_prompt will use retry info
else: # Failure at step 3 (AF2M) is terminal for the workflow
logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: Failure at critical AF2-Multimer step. Terminating workflow.")
else: # Should never reach here with MAX_INTERNAL_STEPS = 4, but keeping for safety
logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: Failure at non-retryable step. Terminating workflow.")
workflow_state["workflow_complete_flag"] = True
break # Exit the internal while loop
@ -778,24 +975,101 @@ class BinderBenchEnv(BaseEnv):
# No break here, loop condition will handle it
# After the internal while loop (for process mode)
if not all_turns_data_for_jsonl: return None, []
last_turn_data = all_turns_data_for_jsonl[-1]
aggregated_messages = [turn_data["messages_this_turn"] for turn_data in all_turns_data_for_jsonl]
aggregated_overrides = [turn_data["overrides_this_turn"] for turn_data in all_turns_data_for_jsonl]
final_reward_for_group = workflow_state.get("cumulative_reward", 0.0)
if workflow_state.get("complex_evaluated") and workflow_state.get("last_tool_success"):
final_reward_for_group = last_turn_data["overrides_this_turn"].get("overall_reward", 0.0)
if not all_turns_data_for_jsonl:
logger.warning(f"Workflow {item_id} in process mode: No turn data collected.")
return None, []
# --- Start of Fix for jsonl2html ---
html_compatible_messages: List[str] = []
html_compatible_scores: List[float] = []
# `overrides_for_jsonl` will store the detailed scoring dict for each turn,
# matching the structure of `html_compatible_messages` and `html_compatible_scores`.
overrides_for_jsonl: List[Dict[str, Any]] = []
for turn_idx, turn_data in enumerate(all_turns_data_for_jsonl):
# Format messages for this turn into a single readable string
turn_str_parts = [f"--- Workflow {item_id} - Turn {turn_idx + 1} ---"]
if turn_data.get("messages_this_turn"):
for msg_obj in turn_data["messages_this_turn"]:
content_str = str(msg_obj.get("content", "[No Content]"))
if msg_obj.get("tool_calls"):
try:
tool_calls_str = json.dumps(msg_obj.get("tool_calls"), indent=2)
content_str += f"\nTool Calls:\n{tool_calls_str}"
except TypeError: # Handle non-serializable content if any
content_str += f"\nTool Calls: [Error serializing tool_calls]"
turn_str_parts.append(f"**{msg_obj.get('role', 'unknown').upper()}**: {content_str}")
else:
turn_str_parts.append("No messages recorded for this turn.")
html_compatible_messages.append("\n\n".join(turn_str_parts))
# Get the score for this specific turn
turn_score = turn_data.get("overrides_this_turn", {}).get("overall_reward", 0.0)
html_compatible_scores.append(turn_score)
# Add the detailed scoring dictionary for this turn
overrides_for_jsonl.append(turn_data.get("overrides_this_turn", {}))
final_workflow_reward = workflow_state.get("cumulative_reward", 0.0)
# If the complex was evaluated successfully, the last turn's reward is the final one.
if workflow_state.get("complex_evaluated") and workflow_state.get("last_tool_success"):
final_workflow_reward = all_turns_data_for_jsonl[-1].get("overrides_this_turn", {}).get("overall_reward", 0.0)
# For the ScoredDataGroup that will be handled by BaseEnv
# We need tokens and masks for each "message" (turn) if we want BaseEnv to consider it valid
# For simplicity, we can just repeat the last turn's tokens/masks, or use placeholders
# if actual per-turn tokens aren't critical for the JSONL's main purpose (which is visualization via messages/scores).
# Let's create placeholder tokens/masks if full history isn't needed by the trainer for process_mode.
# Or, better, store actual tokens for each turn if available.
all_tokens_per_turn = [turn_data["tokens_this_turn"] for turn_data in all_turns_data_for_jsonl if turn_data.get("tokens_this_turn")]
all_masks_per_turn = [turn_data["masks_this_turn"] for turn_data in all_turns_data_for_jsonl if turn_data.get("masks_this_turn")]
# Ensure all_tokens_per_turn and all_masks_per_turn have same length as html_compatible_messages
# If some turns didn't produce tokens (e.g. error), we might need to pad or handle.
# For now, assuming all_turns_data_for_jsonl consistently has tokens/masks for each entry that contributes to html_compatible_messages.
if len(all_tokens_per_turn) != len(html_compatible_messages):
logger.error(f"CRITICAL: Mismatch between tokenized turns ({len(all_tokens_per_turn)}) and HTML messages ({len(html_compatible_messages)}). JSONL will be problematic.")
# Fallback: repeat last turn's tokens if necessary, though this isn't ideal.
if all_turns_data_for_jsonl and all_tokens_per_turn:
last_tokens = all_tokens_per_turn[-1]
last_masks = all_masks_per_turn[-1]
all_tokens_per_turn = [last_tokens] * len(html_compatible_messages)
all_masks_per_turn = [last_masks] * len(html_compatible_messages)
else: # No token data at all
all_tokens_per_turn = [[] for _ in html_compatible_messages]
all_masks_per_turn = [[] for _ in html_compatible_messages]
# This is the ScoredDataGroup that will be written to JSONL by BaseEnv
process_mode_scored_data = ScoredDataGroup(
tokens=[last_turn_data["tokens_this_turn"]],
masks=[last_turn_data["masks_this_turn"]],
scores=[final_reward_for_group],
messages=aggregated_messages if self.config.include_messages else None,
overrides=aggregated_overrides,
group_overrides={"group_size": 1} # ***** THIS IS THE DEFINITIVE FIX FOR THIS PART *****
tokens=all_tokens_per_turn, # List of token lists, one for each turn
masks=all_masks_per_turn, # List of mask lists, one for each turn
# These are critical for jsonl2html
messages=html_compatible_messages, # List of strings, one per turn
scores=html_compatible_scores, # List of floats, one per turn
# Store detailed overrides per turn, matching the length of messages/scores
overrides=overrides_for_jsonl,
group_overrides={
"group_size": len(html_compatible_messages), # Effective group size is number of turns
"item_id": item_id,
"is_process_mode_full_workflow": True,
"final_score_for_workflow": final_workflow_reward, # Store the overall workflow score here
"target_sequence": workflow_state.get("target_sequence", "N/A"),
"designed_binder_sequence": workflow_state.get("designed_binder_sequence", "N/A"),
"final_plddt": workflow_state.get("af2_multimer_plddt", 0.0)
}
)
# Log completed workflow for WandB before adding to metrics
await self.add_rollouts_for_wandb(workflow_state)
# --- End of Fix for jsonl2html ---
# Log detailed workflow state to WandB (this call should use workflow_state directly)
await self.add_rollouts_for_wandb(data_for_log=workflow_state.copy()) # Keep passing workflow_state for detailed wandb logging
self.completed_episode_metrics.append(workflow_state.copy())
if item_id in self.episodes_state: del self.episodes_state[item_id]
return process_mode_scored_data, []
@ -868,25 +1142,25 @@ class BinderBenchEnv(BaseEnv):
workflow_state["retry_count_this_internal_step"] = 0
workflow_state["previous_tool_error_message"] = None
else: # Tool failed or was incorrect
if workflow_state["current_internal_step"] < 3: # Only retry for steps 0-2
if workflow_state["current_internal_step"] <= 3: # Retry for steps 0, 1, 2, AND 3
workflow_state["retry_count_this_internal_step"] += 1
if workflow_state["retry_count_this_internal_step"] > self.config.max_retries_per_internal_step:
logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']} (Serve Mode): Max retries reached. Terminating.")
workflow_state["workflow_complete_flag"] = True
# else: it will be added to backlog below to retry
else: # Failure at step 3 (AF2M) or other non-retryable step
logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']} (Serve Mode): Failure at critical/non-retryable step. Terminating.")
else: # Failure at non-retryable step (should never reach here with MAX_INTERNAL_STEPS = 4)
logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']} (Serve Mode): Failure at non-retryable step. Terminating.")
workflow_state["workflow_complete_flag"] = True
if workflow_state["current_internal_step"] < 4 and not workflow_state.get("workflow_complete_flag"):
# Add to backlog if:
# 1. Last tool was successful (to move to next step)
# OR
# 2. Last tool failed, current step is < 3, and we haven't hit max retries (to retry current step)
# 2. Last tool failed, current step is <= 3, and we haven't hit max retries (to retry current step)
should_add_to_backlog = False
if workflow_state["last_tool_success"]:
should_add_to_backlog = True
elif workflow_state["current_internal_step"] < 3 and \
elif workflow_state["current_internal_step"] <= 3 and \
workflow_state["retry_count_this_internal_step"] <= self.config.max_retries_per_internal_step:
should_add_to_backlog = True
@ -897,10 +1171,16 @@ class BinderBenchEnv(BaseEnv):
logger.info(f"Workflow for {item_id} (Serve Mode) not added to backlog and marked complete. Internal step: {workflow_state['current_internal_step']}")
if workflow_state.get("workflow_complete_flag"): # If flag was set either by reaching step 4 or by retry logic
# Log completed workflow for WandB before adding to metrics
await self.add_rollouts_for_wandb(workflow_state)
self.completed_episode_metrics.append(workflow_state.copy())
if item_id in self.episodes_state: del self.episodes_state[item_id]
# For completed workflows in serve mode, use direct logging with workflow_state
# before it gets deleted
if item_id in self.episodes_state:
# Use direct workflow_state logging for maximum detail
await self.add_rollouts_for_wandb(data_for_log=self.episodes_state[item_id].copy())
self.completed_episode_metrics.append(self.episodes_state[item_id].copy())
del self.episodes_state[item_id]
# Note: We don't need to call add_rollouts_for_wandb with scored_data_serve here
# BaseEnv.handle_send_to_api will call it automatically with the scored_data_serve
# that we return
return scored_data_serve, backlog_items_serve
@ -968,7 +1248,7 @@ class BinderBenchEnv(BaseEnv):
logger.info(f"Workflow {workflow_state['item_id']}, Step {internal_step} (AF2-Multimer): pLDDT={plddt:.2f}. Reward: {detailed_scores['overall_reward']:.2f}")
else:
detailed_scores["overall_reward"] = -0.5
detailed_scores["overall_reward"] = 0.0
# Report the reward value assigned just above rather than a hard-coded literal, so the log message cannot drift from the actual score.
logger.warning(f"Workflow {workflow_state['item_id']}, Step {internal_step} (AF2-Multimer): Evaluation failed or wrong tool. Reward: {detailed_scores['overall_reward']:.2f}. Last tool success: {last_tool_success}, Called: {called_tool_name}")
else:
@ -1046,24 +1326,62 @@ class BinderBenchEnv(BaseEnv):
# Given it's populated by collect_trajectories, clearing seems appropriate for periodic eval.
self.completed_episode_metrics.clear()
async def add_rollouts_for_wandb(self, workflow_state: Dict):
"""Adds a completed workflow summary to the wandb rollout buffer."""
async def add_rollouts_for_wandb(self,
scored_data_group: ScoredDataGroup = None, # From BaseEnv
item_id: Item = None, # From BaseEnv
data_for_log: Dict = None): # Our custom param for direct workflow_state logging
"""Adds a workflow summary to the wandb rollout buffer.
This method has two modes of operation:
1. Direct logging with workflow_state (preferred for detailed logging):
- Called from within collect_trajectories with data_for_log=workflow_state.copy()
- This provides maximum detail for logging
2. BaseEnv compatibility mode:
- Called from BaseEnv.handle_send_to_api with scored_data_group and item_id
- Used automatically by the framework
- May have less detail if workflow_state was already deleted
Args:
scored_data_group: The ScoredDataGroup containing token, mask, and score data (from BaseEnv)
item_id: The item identifier, which is the key to our episodes_state (from BaseEnv)
data_for_log: Direct workflow state to log (our custom parameter for direct logging)
"""
if not self.config.use_wandb or not hasattr(self, "rollouts_for_wandb"):
# Ensure rollouts_for_wandb exists, BaseEnv usually inits it in its __init__
# but if not, init here.
# Ensure rollouts_for_wandb exists
if not hasattr(self, "rollouts_for_wandb"):
self.rollouts_for_wandb = []
# Determine the workflow state to use
workflow_state = None
# Case 1: Direct workflow state provided (most detailed)
if data_for_log is not None and isinstance(data_for_log, dict):
workflow_state = data_for_log
# Extract item_id from data_for_log if needed
if item_id is None and "item_id" in workflow_state:
item_id = workflow_state["item_id"]
# Case 2: Try to get workflow_state from episodes_state (if not already deleted)
elif item_id is not None and item_id in self.episodes_state:
workflow_state = self.episodes_state[item_id]
# Case 3: No usable state - early return with a debug log (not warning)
# This happens in BaseEnv.handle_send_to_api after workflow is already completed
if workflow_state is None:
# This is expected in BaseEnv's call after workflow_state is deleted, so use debug level
logger.debug(f"No workflow_state available for WandB logging (item_id={item_id})")
return
# Customize what you want to see in the WandB table for a completed workflow
# Handle cases where values might be None
item_id = workflow_state.get("item_id", "unknown-id")
target_seq = workflow_state.get("target_sequence", "N/A")
# Handle designed_binder which might be None
designed_binder = workflow_state.get("designed_binder_sequence", "N/A")
if designed_binder is None:
designed_binder = "N/A"
plddt = workflow_state.get("af2_multimer_plddt", 0.0)
iptm = workflow_state.get("af2_multimer_iptm", 0.0) # Even if 0, log it
cumulative_reward = workflow_state.get("cumulative_reward", 0.0)
@ -1084,15 +1402,20 @@ class BinderBenchEnv(BaseEnv):
# Safely truncate strings
target_preview = target_seq[:30] + "..." if isinstance(target_seq, str) and len(target_seq) > 30 else target_seq
if designed_binder == "N/A" or designed_binder is None:
binder_preview = "N/A"
else:
binder_preview = designed_binder[:30] + "..." if len(str(designed_binder)) > 30 else designed_binder
# Use item_id from workflow_state if still None
if item_id is None:
item_id = workflow_state.get("item_id", "unknown-id")
# Add to rollouts buffer
self.rollouts_for_wandb.append(
( # This tuple structure will be used by create_rollout_table
item_id,
str(item_id), # Ensure item_id is a string
target_preview,
binder_preview,
f"{plddt:.2f}",