rfdiffusion fix

This commit is contained in:
based-tachikoma 2025-05-19 19:42:48 -07:00
parent 4d9bec44c6
commit de9dfff221
8 changed files with 1253 additions and 104 deletions

View file

@@ -0,0 +1 @@
"""Protein design model API modules."""

View file

@@ -0,0 +1,150 @@
"""AlphaFold2 API integration for NVIDIA NIM."""
import os
import logging
import aiohttp
import json
import asyncio
from typing import Dict, List, Any, Optional
from pathlib import Path
logger = logging.getLogger(__name__)
# Default URL
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
async def call_alphafold2(
sequence: str,
api_key: str,
algorithm: str = "mmseqs2",
e_value: float = 0.0001,
iterations: int = 1,
databases: List[str] = ["small_bfd"],
relax_prediction: bool = False,
skip_template_search: bool = True,
url: str = DEFAULT_URL,
status_url: str = DEFAULT_STATUS_URL,
polling_interval: int = 10,
timeout: int = 600,
max_retries: int = 3  # Maximum number of retries (currently unused in the body below)
) -> Optional[Dict[str, Any]]:
"""
Call the NVIDIA NIM AlphaFold2 API.
Args:
sequence: Protein sequence in one-letter code
api_key: NVIDIA NIM API key
algorithm: MSA search algorithm, "mmseqs2" or "jackhmmer"
e_value: E-value threshold for the MSA search
iterations: Number of MSA search iterations
databases: List of databases to search
relax_prediction: Whether to relax the prediction
skip_template_search: Whether to skip template search
url: API endpoint URL
status_url: Status URL for checking job completion
polling_interval: Seconds between status checks
timeout: Request timeout in seconds
max_retries: Maximum number of retries (currently unused)
Returns:
API response or None on failure
"""
# Prepare headers
headers = {
"content-type": "application/json",
"Authorization": f"Bearer {api_key}",
"NVCF-POLL-SECONDS": "300",
}
# Prepare payload
data = {
"sequence": sequence,
"algorithm": algorithm,
"e_value": e_value,
"iterations": iterations,
"databases": databases,
"relax_prediction": relax_prediction,
"skip_template_search": skip_template_search
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(
url,
json=data,
headers=headers,
timeout=timeout
) as response:
# Check status code
if response.status == 200:
return await response.json()
elif response.status == 202:
# Asynchronous job, get job ID
req_id = response.headers.get("nvcf-reqid")
if req_id:
logger.info(f"AlphaFold2 job submitted, request ID: {req_id}")
return await _poll_job_status(
req_id=req_id,
headers=headers,
status_url=status_url,
polling_interval=polling_interval,
timeout=timeout
)
else:
logger.error("No request ID in response headers")
return None
else:
logger.error(f"Error calling AlphaFold2 API: {response.status}")
text = await response.text()
logger.error(f"Response: {text}")
return None
except Exception as e:
import traceback
logger.error(f"Error calling AlphaFold2 API: {e}")
logger.error(traceback.format_exc())
return None
async def _poll_job_status(
req_id: str,
headers: Dict[str, str],
status_url: str,
polling_interval: int = 10,
timeout: int = 60
) -> Optional[Dict[str, Any]]:
"""
Poll the status endpoint until the job completes.
Args:
req_id: The request ID to check
headers: Request headers
status_url: Status URL for checking job completion
polling_interval: Seconds between status checks
timeout: Request timeout in seconds
Returns:
The final response or None on failure
"""
while True:
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{status_url}/{req_id}",
headers=headers,
timeout=timeout
) as response:
if response.status == 200:
# Job completed
logger.info(f"AlphaFold2 job {req_id} completed")
return await response.json()
elif response.status == 202:
# Job still running
logger.debug(f"AlphaFold2 job {req_id} still running, polling...")
await asyncio.sleep(polling_interval)
else:
logger.error(f"Error checking AlphaFold2 job status: {response.status}")
text = await response.text()
logger.error(f"Response: {text}")
return None
except Exception as e:
logger.error(f"Error polling AlphaFold2 job status: {e}")
return None
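
For reference, a minimal usage sketch for call_alphafold2. The import path and the NVIDIA_API_KEY environment variable are illustrative assumptions; only the function signature above comes from this commit.

import asyncio
import os

from nim_api.alphafold2 import call_alphafold2  # hypothetical module path

async def main() -> None:
    # Submit a short test sequence and wait for the (possibly polled) result.
    result = await call_alphafold2(
        sequence="MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
        api_key=os.environ["NVIDIA_API_KEY"],
    )
    if result is None:
        print("AlphaFold2 call failed; see logs for details")
    else:
        print("Received AlphaFold2 response")

if __name__ == "__main__":
    asyncio.run(main())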

View file

@@ -0,0 +1,366 @@
"""AlphaFold2-Multimer API integration for NVIDIA NIM."""
import os
import logging
import aiohttp
import json
import asyncio
from typing import Dict, List, Any, Optional, Tuple
from pathlib import Path
import zipfile
import io
logger = logging.getLogger(__name__)
# Default URL
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2-multimer"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
# Helper functions
def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
"""
Splits a string containing concatenated PDB file contents.
Assumes models are separated by "ENDMDL" or just "END" for the last/single model.
"""
pdbs = []
current_pdb_lines = []
if not concatenated_pdb_str:
return []
for line in concatenated_pdb_str.splitlines(keepends=True):
current_pdb_lines.append(line)
if line.startswith("ENDMDL") or line.startswith("END "):
pdbs.append("".join(current_pdb_lines).strip())
current_pdb_lines = []
if current_pdb_lines:
remaining_pdb = "".join(current_pdb_lines).strip()
if remaining_pdb:
pdbs.append(remaining_pdb)
return [pdb for pdb in pdbs if pdb]
def calculate_plddt_from_pdb_string(pdb_string: str) -> Tuple[float, List[float], Dict[str, List[float]]]:
"""
Calculates the average pLDDT score from a PDB string for C-alpha atoms.
Also returns a list of all C-alpha pLDDTs and a dictionary of pLDDTs per chain.
Returns:
A tuple containing:
- average_plddt (float): Average pLDDT over all C-alpha atoms.
- plddt_scores_per_ca (List[float]): List of pLDDTs for each C-alpha atom.
- plddt_scores_per_chain (Dict[str, List[float]]): Dict mapping chain ID to its C-alpha pLDDTs.
"""
total_plddt = 0.0
ca_atom_count = 0
plddt_scores_per_ca: List[float] = []
plddt_scores_per_chain: Dict[str, List[float]] = {}
for line in pdb_string.splitlines():
if line.startswith("ATOM"):
atom_name = line[12:16].strip()
if atom_name == "CA":
try:
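# AlphaFold-style PDB output stores per-residue pLDDT in the B-factor field (columns 61-66), parsed below.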
plddt_value = float(line[60:66].strip())
total_plddt += plddt_value
plddt_scores_per_ca.append(plddt_value)
ca_atom_count += 1
chain_id = line[21:22].strip()
if chain_id not in plddt_scores_per_chain:
plddt_scores_per_chain[chain_id] = []
plddt_scores_per_chain[chain_id].append(plddt_value)
except ValueError:
pass
except IndexError:
pass
if ca_atom_count == 0:
return 0.0, [], {}
average_plddt = total_plddt / ca_atom_count
return average_plddt, plddt_scores_per_ca, plddt_scores_per_chain
async def _process_nvidia_zip_output(
zip_content: bytes,
output_prefix: str,
response_headers: Optional[Dict[str, str]] = None
) -> Optional[Dict[str, Any]]:
"""
Processes the ZIP file content from NVIDIA NIM.
- Expects a .response file with concatenated PDBs, or individual PDB files.
- Extracts PDBs and calculates pLDDT scores for each structure.
- Returns paths to saved files and calculated pLDDT scores.
"""
output_dir = Path(f"./{output_prefix}_results")
output_dir.mkdir(parents=True, exist_ok=True)
# Save the original ZIP file
zip_file_path = output_dir / f"{output_prefix}.zip"
with open(zip_file_path, 'wb') as f:
f.write(zip_content)
logger.info(f"Downloaded and saved original ZIP file to {zip_file_path}")
# Initialize the results dictionary that call_alphafold2_multimer will return
results: Dict[str, Any] = {
"zip_file_path": str(zip_file_path), # Path to the saved original ZIP
"structures": [], # List to hold info for each PDB structure found
# iptm/ptm are not parsed from ranking_debug.json; kept as None so the result schema stays stable
"iptm_score": None,
"ptm_score": None,
# Path to the extracted .response file, if one is found in the ZIP
"extracted_response_file_path": None,
}
pdb_strings_to_process = []
try:
with zipfile.ZipFile(io.BytesIO(zip_content)) as zf:
# First, check for a ".response" file with concatenated PDBs
response_file_name = None
for member_name in zf.namelist():
if member_name.lower().endswith(".response"):
response_file_name = member_name
break
if response_file_name:
logger.info(f"Found concatenated response file in ZIP: {response_file_name}")
response_data = zf.read(response_file_name)
response_content_str = response_data.decode('utf-8', errors='replace')
# Save the raw .response file
extracted_response_file_path = output_dir / Path(response_file_name).name
with open(extracted_response_file_path, 'w', encoding='utf-8', errors='replace') as f_resp:
f_resp.write(response_content_str)
results["extracted_response_file_path"] = str(extracted_response_file_path)
logger.info(f"Saved raw content of '{response_file_name}' to {extracted_response_file_path}")
pdb_strings_to_process.extend(_split_pdb_content(response_content_str))
else:
# If no ".response" file, look for individual PDB files
logger.info("No .response file found. Looking for individual .pdb files in ZIP.")
for member_name in zf.namelist():
if member_name.lower().endswith(".pdb"):
logger.info(f"Found individual PDB file in ZIP: {member_name}")
pdb_content_bytes = zf.read(member_name)
pdb_strings_to_process.append(pdb_content_bytes.decode('utf-8', errors='replace'))
if not pdb_strings_to_process:
logger.warning(f"No PDB content found in ZIP archive {zip_file_path} (either as .response or individual .pdb files).")
return results # Return with empty structures list
logger.info(f"Found {len(pdb_strings_to_process)} PDB structure(s) to process.")
for i, pdb_str in enumerate(pdb_strings_to_process):
if not pdb_str.strip(): # Skip empty PDB strings
logger.debug(f"Skipping empty PDB string at index {i}.")
continue
structure_data: Dict[str, Any] = {
"model_index": i, # 0-indexed based on order found
"pdb_content": pdb_str # Store the raw PDB string
}
# Calculate pLDDT scores for this structure
avg_plddt, plddts_per_ca_residue, plddts_by_chain = calculate_plddt_from_pdb_string(pdb_str)
structure_data["average_plddt"] = avg_plddt
structure_data["plddt_scores_per_ca_residue"] = plddts_per_ca_residue # List of pLDDTs for CAs
structure_data["plddt_scores_per_chain"] = plddts_by_chain # Dict: chain_id -> List[pLDDT]
# Calculate the average pLDDT for each chain
avg_plddt_per_chain = {}
for chain_id, chain_plddts in plddts_by_chain.items():
if chain_plddts: # Avoid division by zero
avg_plddt_per_chain[chain_id] = sum(chain_plddts) / len(chain_plddts)
else:
avg_plddt_per_chain[chain_id] = 0.0
structure_data["average_plddt_per_chain"] = avg_plddt_per_chain
# Save the individual PDB string to a file
pdb_file_name_stem = Path(output_prefix).stem
rank_suffix = f"_model_{i}"  # Index suffix keeps file names unique when multiple models are returned
pdb_file_path = output_dir / f"{pdb_file_name_stem}{rank_suffix}.pdb"
try:
with open(pdb_file_path, "w", encoding='utf-8') as f_pdb:
f_pdb.write(pdb_str)
structure_data["saved_pdb_path"] = str(pdb_file_path)
logger.info(f"Saved PDB model {i} to {pdb_file_path} with overall avg_pLDDT: {avg_plddt:.2f}")
except Exception as e_write:
logger.error(f"Failed to write PDB file {pdb_file_path}: {e_write}")
structure_data["saved_pdb_path"] = None
results["structures"].append(structure_data)
if results["structures"]:
logger.info(f"Successfully processed and calculated pLDDTs for {len(results['structures'])} structures.")
except zipfile.BadZipFile:
logger.error(f"Failed to process ZIP: {zip_file_path} is not a valid ZIP file.")
# results dictionary is already initialized, will be returned as is, potentially empty.
except Exception as e:
logger.error(f"An error occurred during ZIP processing of {zip_file_path}: {e}", exc_info=True)
# As above, return results which might be partially filled or empty.
return results
async def call_alphafold2_multimer(
sequences: List[str],
api_key: str,
algorithm: str = "jackhmmer",
e_value: float = 0.0001,
iterations: int = 1,
databases: List[str] = ["uniref90", "small_bfd", "mgnify"],
relax_prediction: bool = True,
selected_models: Optional[List[int]] = None,
url: str = DEFAULT_URL,
status_url: str = DEFAULT_STATUS_URL,
polling_interval: int = 30,
timeout: int = 3600
) -> Optional[Dict[str, Any]]:
"""
Call the NVIDIA NIM AlphaFold2-Multimer API.
Returns a dictionary structured by _process_nvidia_zip_output.
"""
headers = {
"content-type": "application/json",
"Authorization": f"Bearer {api_key}",
"NVCF-POLL-SECONDS": "300",
}
data: Dict[str, Any] = {
"sequences": sequences,
"algorithm": algorithm,
"e_value": e_value,
"iterations": iterations,
"databases": databases,
"relax_prediction": relax_prediction
}
if selected_models is not None:
data["selected_models"] = selected_models
logger.info(f"Using selected_models: {selected_models}")
try:
initial_post_timeout = min(timeout, 600)
async with aiohttp.ClientSession() as session:
async with session.post(
url,
json=data,
headers=headers,
timeout=initial_post_timeout
) as response:
if response.status == 200:
logger.info("AlphaFold2-Multimer job completed synchronously.")
content = await response.read()
return await _process_nvidia_zip_output(
zip_content=content,
output_prefix="alphafold2_multimer_sync_output",
response_headers=response.headers
)
elif response.status == 202:
req_id = response.headers.get("nvcf-reqid")
if req_id:
logger.info(f"AlphaFold2-Multimer job submitted, request ID: {req_id}")
return await _poll_job_status(
req_id=req_id,
headers=headers,
status_url=status_url,
polling_interval=polling_interval,
overall_timeout=timeout
)
else:
logger.error("No request ID in 202 response headers")
return None
else: # Handle other error statuses from POST
logger.error(f"Error calling AlphaFold2-Multimer API (POST): {response.status}")
text = await response.text()
logger.error(f"Response: {text}")
return None
except asyncio.TimeoutError:
logger.error(f"Timeout during AlphaFold2-Multimer API (initial POST).")
return None
except Exception as e:
logger.error(f"Exception during AlphaFold2-Multimer API call (initial POST): {e}", exc_info=True)
return None
async def _poll_job_status(
req_id: str,
headers: Dict[str, str],
status_url: str,
polling_interval: int = 30,
overall_timeout: int = 3600
) -> Optional[Dict[str, Any]]:
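"""Poll the status endpoint until the job completes; download and process the result ZIP on success."""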
start_time = asyncio.get_event_loop().time()
# Allow each status check to wait slightly longer than the server-side long poll:
# the POST header sets NVCF-POLL-SECONDS to 300s, so give the GET to /status some headroom.
status_check_timeout = 330  # seconds
logger.info(f"Polling job {req_id}. Status check timeout: {status_check_timeout}s, Polling interval: {polling_interval}s, Overall timeout: {overall_timeout}s")
while True:
current_loop_time = asyncio.get_event_loop().time()
elapsed_time = current_loop_time - start_time
if elapsed_time > overall_timeout:
logger.error(f"Overall polling timeout of {overall_timeout}s exceeded for job {req_id}.")
return None
remaining_time_for_overall_timeout = overall_timeout - elapsed_time
current_status_check_timeout = min(status_check_timeout, remaining_time_for_overall_timeout)
if current_status_check_timeout <= 0:
logger.error(f"Not enough time left for another status check for job {req_id} within overall_timeout.")
return None
try:
async with aiohttp.ClientSession() as session:
logger.debug(f"Checking status for {req_id} with timeout {current_status_check_timeout}s.")
async with session.get(
f"{status_url}/{req_id}",
headers=headers,
timeout=current_status_check_timeout
) as response:
if response.status == 200:
logger.info(f"AlphaFold2-Multimer job {req_id} completed (status 200).")
logger.info(f"FINAL 200 OK Response Headers for job {req_id}: {response.headers}")
logger.info(f"FINAL 200 OK Content-Type for job {req_id}: {response.content_type}")
zip_content_bytes = await response.read()
return await _process_nvidia_zip_output(
zip_content=zip_content_bytes,
output_prefix=f"alphafold2_multimer_output_{req_id}",
response_headers=response.headers
)
elif response.status == 202:
try:
job_status_json = await response.json()
percent_complete = job_status_json.get('percentComplete', 'N/A')
status_message = job_status_json.get('status', 'running')
logger.debug(
f"Job {req_id} status: {status_message} ({percent_complete}% complete). Polling again in {polling_interval}s."
)
except (aiohttp.ContentTypeError, json.JSONDecodeError):
logger.debug(
f"Job {req_id} still running (202 status, non-JSON/malformed JSON body). Polling again in {polling_interval}s."
)
await asyncio.sleep(polling_interval)
else: # Handle other error statuses from status GET
logger.error(f"Error checking AlphaFold2-Multimer job status {req_id}: {response.status}")
text = await response.text()
logger.error(f"Response: {text}")
# Any status other than 200/202 (e.g. 4xx, 500, 504) is treated as fatal for now and ends polling;
# retrying transient server errors a few times before giving up would be a reasonable improvement.
return None
except asyncio.TimeoutError:
logger.warning(f"Client-side timeout ({current_status_check_timeout}s) during status check for job {req_id}. Retrying poll after {polling_interval}s sleep.")
await asyncio.sleep(polling_interval)
except aiohttp.ClientError as e:
logger.error(f"Client error polling job status for {req_id}: {e}. Retrying poll after {polling_interval}s.", exc_info=True)
await asyncio.sleep(polling_interval)
except Exception as e:
logger.error(f"Unexpected error polling job status {req_id}: {e}", exc_info=True)
return None
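
A minimal usage sketch for call_alphafold2_multimer, reading back the dictionary assembled by _process_nvidia_zip_output. The module path and environment variable are assumptions; the result keys ("structures", "average_plddt", "saved_pdb_path") come from the code above.

import asyncio
import os

from nim_api.alphafold2_multimer import call_alphafold2_multimer  # hypothetical module path

async def main() -> None:
    target = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # toy target sequence
    binder = "GSHMGSGS"                            # toy binder sequence
    result = await call_alphafold2_multimer(
        sequences=[target, binder],
        api_key=os.environ["NVIDIA_API_KEY"],
    )
    if result and result.get("structures"):
        best = result["structures"][0]
        print(f"Average pLDDT: {best['average_plddt']:.2f}")
        print(f"PDB written to: {best.get('saved_pdb_path')}")
    else:
        print("AlphaFold2-Multimer call failed or returned no structures")

if __name__ == "__main__":
    asyncio.run(main())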

View file

@@ -0,0 +1,138 @@
"""ProteinMPNN API integration for NVIDIA NIM."""
import os
import logging
import aiohttp
import json
import asyncio
from typing import Dict, List, Any, Optional, Union
from pathlib import Path
logger = logging.getLogger(__name__)
# Default URL
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/ipd/proteinmpnn/predict"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
async def call_proteinmpnn(
input_pdb: str,
api_key: str,
ca_only: bool = False,
use_soluble_model: bool = False,
sampling_temp: List[float] = [0.1],
url: str = DEFAULT_URL,
status_url: str = DEFAULT_STATUS_URL,
polling_interval: int = 10,
timeout: int = 60
) -> Optional[Dict[str, Any]]:
"""
Call the NVIDIA NIM ProteinMPNN API.
Args:
input_pdb: PDB structure as a string
api_key: NVIDIA NIM API key
ca_only: Whether to use only Cα atoms
use_soluble_model: Whether to use the soluble model
sampling_temp: List of sampling temperatures
url: API endpoint URL
status_url: Status URL for checking job completion
polling_interval: Seconds between status checks
timeout: Request timeout in seconds
Returns:
API response or None on failure
"""
# Prepare headers
headers = {
"content-type": "application/json",
"Authorization": f"Bearer {api_key}",
"NVCF-POLL-SECONDS": "300",
}
# Prepare payload
data = {
"input_pdb": input_pdb,
"ca_only": ca_only,
"use_soluble_model": use_soluble_model,
"sampling_temp": sampling_temp
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(
url,
json=data,
headers=headers,
timeout=timeout
) as response:
# Check status code
if response.status == 200:
return await response.json()
elif response.status == 202:
# Asynchronous job, get job ID
req_id = response.headers.get("nvcf-reqid")
if req_id:
logger.info(f"ProteinMPNN job submitted, request ID: {req_id}")
return await _poll_job_status(
req_id=req_id,
headers=headers,
status_url=status_url,
polling_interval=polling_interval,
timeout=timeout
)
else:
logger.error("No request ID in response headers")
return None
else:
logger.error(f"Error calling ProteinMPNN API: {response.status}")
text = await response.text()
logger.error(f"Response: {text}")
return None
except Exception as e:
logger.error(f"Error calling ProteinMPNN API: {e}")
return None
async def _poll_job_status(
req_id: str,
headers: Dict[str, str],
status_url: str,
polling_interval: int = 10,
timeout: int = 60
) -> Optional[Dict[str, Any]]:
"""
Poll the status endpoint until the job completes.
Args:
req_id: The request ID to check
headers: Request headers
status_url: Status URL for checking job completion
polling_interval: Seconds between status checks
timeout: Request timeout in seconds
Returns:
The final response or None on failure
"""
while True:
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{status_url}/{req_id}",
headers=headers,
timeout=timeout
) as response:
if response.status == 200:
# Job completed
logger.info(f"ProteinMPNN job {req_id} completed")
return await response.json()
elif response.status == 202:
# Job still running
logger.debug(f"ProteinMPNN job {req_id} still running, polling...")
await asyncio.sleep(polling_interval)
else:
logger.error(f"Error checking ProteinMPNN job status: {response.status}")
text = await response.text()
logger.error(f"Response: {text}")
return None
except Exception as e:
logger.error(f"Error polling ProteinMPNN job status: {e}")
return None
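
A minimal usage sketch for call_proteinmpnn. The input PDB path, module path, and environment variable are assumptions; the "mfasta" key is the field consumed elsewhere in this commit.

import asyncio
import os
from pathlib import Path

from nim_api.proteinmpnn import call_proteinmpnn  # hypothetical module path

async def main() -> None:
    backbone_pdb = Path("binder_backbone.pdb").read_text()  # hypothetical input backbone
    result = await call_proteinmpnn(
        input_pdb=backbone_pdb,
        api_key=os.environ["NVIDIA_API_KEY"],
        sampling_temp=[0.1, 0.2],
    )
    if result and "mfasta" in result:
        # The multi-FASTA output contains one or more designed sequences.
        print(result["mfasta"])
    else:
        print("ProteinMPNN call failed")

if __name__ == "__main__":
    asyncio.run(main())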

View file

@@ -0,0 +1,142 @@
"""RFDiffusion API integration for NVIDIA NIM."""
import os
import logging
import aiohttp
import json
import asyncio
from typing import Dict, List, Any, Optional, Union
from pathlib import Path
logger = logging.getLogger(__name__)
# Default URL
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/ipd/rfdiffusion/generate"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
async def call_rfdiffusion(
input_pdb: str,
api_key: str,
contigs: Optional[str] = None,
hotspot_res: Optional[List[str]] = None,
diffusion_steps: int = 15,
url: str = DEFAULT_URL,
status_url: str = DEFAULT_STATUS_URL,
polling_interval: int = 10,
timeout: int = 60
) -> Optional[Dict[str, Any]]:
"""
Call the NVIDIA NIM RFDiffusion API.
Args:
input_pdb: PDB structure as a string
api_key: NVIDIA NIM API key
contigs: Contig string (e.g. "A20-60/0 50-100")
hotspot_res: List of hotspot residues (e.g. ["A50","A51"])
diffusion_steps: Number of diffusion steps
url: API endpoint URL
status_url: Status URL for checking job completion
polling_interval: Seconds between status checks
timeout: Request timeout in seconds
Returns:
API response or None on failure
"""
# Prepare headers
headers = {
"content-type": "application/json",
"Authorization": f"Bearer {api_key}",
"NVCF-POLL-SECONDS": "300",
}
# Prepare payload
data = {
"input_pdb": input_pdb,
"diffusion_steps": diffusion_steps
}
# Add optional parameters if provided
if contigs:
data["contigs"] = contigs
if hotspot_res:
data["hotspot_res"] = hotspot_res
try:
async with aiohttp.ClientSession() as session:
async with session.post(
url,
json=data,
headers=headers,
timeout=timeout
) as response:
# Check status code
if response.status == 200:
return await response.json()
elif response.status == 202:
# Asynchronous job, get job ID
req_id = response.headers.get("nvcf-reqid")
if req_id:
logger.info(f"RFDiffusion job submitted, request ID: {req_id}")
return await _poll_job_status(
req_id=req_id,
headers=headers,
status_url=status_url,
polling_interval=polling_interval,
timeout=timeout
)
else:
logger.error("No request ID in response headers")
return None
else:
logger.error(f"Error calling RFDiffusion API: {response.status}")
text = await response.text()
logger.error(f"Response: {text}")
return None
except Exception as e:
logger.error(f"Error calling RFDiffusion API: {e}")
return None
async def _poll_job_status(
req_id: str,
headers: Dict[str, str],
status_url: str,
polling_interval: int = 10,
timeout: int = 60
) -> Optional[Dict[str, Any]]:
"""
Poll the status endpoint until the job completes.
Args:
req_id: The request ID to check
headers: Request headers
status_url: Status URL for checking job completion
polling_interval: Seconds between status checks
timeout: Request timeout in seconds
Returns:
The final response or None on failure
"""
while True:
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{status_url}/{req_id}",
headers=headers,
timeout=timeout
) as response:
if response.status == 200:
# Job completed
logger.info(f"RFDiffusion job {req_id} completed")
return await response.json()
elif response.status == 202:
# Job still running
logger.debug(f"RFDiffusion job {req_id} still running, polling...")
await asyncio.sleep(polling_interval)
else:
logger.error(f"Error checking RFDiffusion job status: {response.status}")
text = await response.text()
logger.error(f"Response: {text}")
return None
except Exception as e:
logger.error(f"Error polling RFDiffusion job status: {e}")
return None
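
A minimal usage sketch for call_rfdiffusion using the contig and hotspot formats from the docstring above. The module path, input file, and environment variable are assumptions, and no particular response schema is assumed here.

import asyncio
import os
from pathlib import Path

from nim_api.rfdiffusion import call_rfdiffusion  # hypothetical module path

async def main() -> None:
    target_pdb = Path("target.pdb").read_text()  # hypothetical target structure
    result = await call_rfdiffusion(
        input_pdb=target_pdb,
        api_key=os.environ["NVIDIA_API_KEY"],
        contigs="A20-60/0 50-100",      # keep target residues A20-60, then diffuse a 50-100 residue binder
        hotspot_res=["A50", "A51"],
        diffusion_steps=15,
    )
    if result is None:
        print("RFDiffusion call failed")
    else:
        print("RFDiffusion response keys:", list(result.keys()) if isinstance(result, dict) else type(result))

if __name__ == "__main__":
    asyncio.run(main())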

View file

@@ -3,9 +3,10 @@ import json
import logging
import os
import random
import re
import uuid
from pathlib import Path
from typing import Dict, List, Any, Tuple, Optional, Union, TypedDict
from typing import Dict, List, Any, Tuple, Optional, Union, TypedDict, Set
import yaml
import wandb
@@ -89,6 +90,85 @@ def load_target_binder_pairs(dataset_name: str, target_col: str, binder_col: str
return ds
def get_pdb_chain_details(pdb_content: str, preview_lines: int = 10) -> Tuple[Dict[str, Dict[str, int]], str]:
"""
Parses PDB content to extract detailed information for each chain.
Returns:
A tuple containing:
- chain_details (Dict[str, Dict[str, int]]):
A dictionary where keys are chain IDs (e.g., "A").
Each value is another dictionary:
{
"min_residue": int, # Smallest residue number found for this chain
"max_residue": int, # Largest residue number found for this chain
"length": int # Count of unique C-alpha atoms (residues) in this chain
}
- pdb_preview (str): A string preview of the PDB content.
"""
chain_info_temp: Dict[str, Dict[str, Union[Set[int], int]]] = {} # Stores residue numbers and CA count for each chain
atom_lines = []
header_lines = []
# First pass: Collect all residue numbers and CA atoms per chain
for line in pdb_content.splitlines():
if line.startswith("ATOM"): # Consider only ATOM records for canonical residues
atom_lines.append(line)
chain_id = line[21:22].strip()
if not chain_id:
chain_id = " " # Default for blank chain ID, consider how RFDiffusion handles this
atom_name = line[12:16].strip()
try:
residue_num = int(line[22:26].strip())
if chain_id not in chain_info_temp:
chain_info_temp[chain_id] = {"residues": set(), "ca_count": 0}
chain_info_temp[chain_id]["residues"].add(residue_num)
if atom_name == "CA":
chain_info_temp[chain_id]["ca_count"] += 1
except ValueError:
logger.warning(f"Could not parse residue number from PDB line: {line}")
continue
elif line.startswith("HEADER") or line.startswith("TITLE") or line.startswith("COMPND"):
header_lines.append(line)
# Second pass: Calculate min, max, and length from collected data
chain_details: Dict[str, Dict[str, int]] = {}
for chain_id, data in chain_info_temp.items():
if data["residues"]: # Only process if residues were found
min_res = min(data["residues"])
max_res = max(data["residues"])
# Length can be defined in several ways:
#   1. max_res - min_res + 1 (assumes contiguous numbering)
#   2. Count of unique residue numbers (robust to numbering gaps)
#   3. Count of C-alpha atoms (a good proxy for the residues actually modeled)
# Use the C-alpha count when available; fall back to the unique-residue count if no CA atoms were seen.
length = data["ca_count"] if data["ca_count"] > 0 else len(data["residues"])
chain_details[chain_id] = {
"min_residue": min_res,
"max_residue": max_res,
"length": length
}
else:
logger.warning(f"Chain {chain_id} had no parseable ATOM residue numbers.")
# Construct PDB preview
preview_str_parts = header_lines[:min(len(header_lines), preview_lines // 2)]
remaining_preview_lines = preview_lines - len(preview_str_parts)
preview_str_parts.extend(atom_lines[:min(len(atom_lines), remaining_preview_lines)])
pdb_preview = "\n".join(preview_str_parts)
if len(pdb_content.splitlines()) > preview_lines:
pdb_preview += "\n..."
return chain_details, pdb_preview
def get_pdb_chain_lengths_and_preview(pdb_content: str, preview_lines: int = 10) -> Tuple[Dict[str, int], str]:
chain_lengths = {}
current_chain_id = None
@@ -130,11 +210,10 @@ def get_pdb_chain_lengths_and_preview(pdb_content: str, preview_lines: int = 10)
return chain_lengths, pdb_preview
def construct_user_prompt(state: dict) -> str: # state is an item from self.episodes_state
internal_step = state.get("current_internal_step", 0) # Use internal step from our state
internal_step = state.get("current_internal_step", 0)
target_sequence = state.get("target_sequence")
user_prompt_str = ""
# Base prompt construction for the current workflow step
if internal_step == 0: # Step 1: Predict Target Structure (AlphaFold2)
user_prompt_str = (
f"The target protein sequence is: {target_sequence}. "
@@ -142,52 +221,85 @@ def construct_user_prompt(state: dict) -> str: # state is an item from self.epis
"You must provide the 'sequence' argument."
)
elif internal_step == 1: # Step 2: Design Binder Backbone (RFDiffusion)
# Use the stored preview and chain info
target_pdb_preview = state.get("target_pdb_preview", "PDB preview not available.")
chain_info = state.get("target_chain_info", {})
chain_info_str = ", ".join([f"Chain {cID} (length {length})" for cID, length in chain_info.items()])
if not chain_info_str: chain_info_str = "Chain information not available."
target_pdb_preview = state.get("target_pdb_preview", "PDB preview not available.") # Can keep preview for general context
# --- NEW CHAIN INFO FORMATTING ---
chain_details = state.get("target_chain_details", {}) # Get the new detailed info
if chain_details:
chain_info_parts = []
for chain_id, details in chain_details.items():
min_r = details.get('min_residue', 'N/A')
max_r = details.get('max_residue', 'N/A')
l = details.get('length', 'N/A')
chain_info_parts.append(f"Chain {chain_id} (Residues: {min_r}-{max_r}, Length: {l} amino acids)")
chain_info_str = "\n- ".join(chain_info_parts)
if chain_info_str:
chain_info_str = "- " + chain_info_str # Add leading bullet for the first item
else:
chain_info_str = "Chain information not available or PDB not yet processed."
# --- END NEW CHAIN INFO FORMATTING ---
user_prompt_str = (
f"The 3D structure of the target protein has been predicted. Target PDB preview:\n{target_pdb_preview}\n"
f"Target chain information: {chain_info_str}.\n"
"Now, design a binder backbone using the 'design_binder_backbone_rfdiffusion' tool. "
"You need to specify 'contigs'. Contigs define segments from the target PDB (e.g., 'A1-100' means residues 1-100 of target chain A) "
"and segments for the new binder (e.g., '/0 50-70' means generate a new chain of length 50 to 70 residues). "
"A full example for a 60-residue binder for Chain A of the target (if Chain A has 100 residues): 'A1-100/0 60'. "
"Ensure any target residue numbers in contigs are within the valid range for the respective chain. "
"Optionally, provide 'hotspot_residues' (e.g., ['A50', 'A52']), ensuring they exist on the target."
f"The 3D structure of the target protein has been predicted.\n"
# Optional: f"Target PDB preview:\n{target_pdb_preview}\n\n"
f"Target Protein Chain Details:\n{chain_info_str}\n\n" # Use the detailed chain info
"Your task is to design a binder backbone using the 'design_binder_backbone_rfdiffusion' tool. "
"You MUST specify 'contigs' for this tool. The 'contigs' string defines segments from the target PDB and segments for the new binder. "
"Examples:\n"
" - To use residues 10 through 100 of target chain A, and then diffuse a 60-residue binder: 'A10-100/0 60'\n"
" - To use chain B from residue 5 to 50, then diffuse a 30-residue binder, then use chain B from residue 60 to 100: 'B5-50/0 30 B60-100'\n"
"You MUST use the chain IDs and residue ranges exactly as provided in the 'Target Protein Chain Details' above. "
"Do not invent chains or residue numbers outside these specified ranges for the target segments. "
"For binder segments (e.g., '/0 60'), specify the desired length (e.g., 60).\n"
"Optionally, provide 'hotspot_residues' (e.g., ['A50', 'A52']), ensuring these residues exist on the target as per the details above."
)
elif internal_step == 2: # Step 3: Design Binder Sequence (ProteinMPNN)
binder_pdb_preview = state.get("binder_pdb_preview", "Binder PDB preview not available.")
binder_chain_info = state.get("binder_chain_info", {}) # Info about the binder backbone itself
binder_info_str = ", ".join([f"Chain {cID} (length {length})" for cID, length in binder_chain_info.items()])
if not binder_info_str: binder_info_str = "Binder chain information not available."
# Get detailed binder chain information using the get_pdb_chain_details function
binder_pdb_content = state.get("binder_backbone_pdb_content")
if binder_pdb_content:
binder_chain_details, binder_pdb_preview = get_pdb_chain_details(binder_pdb_content)
binder_chain_info_str = "\n- ".join([f"Chain {cID} (Residues: {d.get('min_residue','N/A')}-{d.get('max_residue','N/A')}, Length: {d.get('length','N/A')})" for cID, d in binder_chain_details.items()])
if binder_chain_info_str: binder_chain_info_str = "- " + binder_chain_info_str
else:
binder_pdb_preview = "Binder PDB preview not available."
binder_chain_info_str = "Binder chain information not available."
user_prompt_str = (
f"A binder backbone has been generated. Binder PDB preview:\n{binder_pdb_preview}\n"
f"Binder chain information: {binder_info_str}.\n"
f"Binder chain information:\n{binder_chain_info_str}.\n"
"Now, design an optimal amino acid sequence for this binder backbone using the 'design_binder_sequence_proteinmpnn' tool. "
"You can optionally specify 'sampling_temp' (e.g., [0.1, 0.2])."
)
elif internal_step == 3: # Step 4: Evaluate Complex (AlphaFold2-Multimer)
designed_binder_seq = state.get("designed_binder_sequence", "Not yet available")
designed_binder_seq_data = state.get("designed_binder_sequence") # This is List[str]
binder_display_str = "Not available"
if isinstance(designed_binder_seq_data, list) and designed_binder_seq_data:
if len(designed_binder_seq_data) == 1:
binder_display_str = designed_binder_seq_data[0]
else:
binder_display_str = f"{len(designed_binder_seq_data)} chains: " + \
", ".join([f"Chain {i+1} ({len(s)} aa): {s[:20]}..."
for i, s in enumerate(designed_binder_seq_data)])
elif isinstance(designed_binder_seq_data, str): # Should not happen with new PMPNN parsing
binder_display_str = designed_binder_seq_data
user_prompt_str = (
f"A binder sequence has been designed: {designed_binder_seq}. "
f"The original target sequence was: {target_sequence}.\n" # Remind LLM of original target
"Finally, evaluate the binding complex of the original target protein and this designed binder using the "
f"A binder has been designed. Designed binder sequence(s): {binder_display_str}. "
f"The original target sequence was: {target_sequence[:60]}...\n"
"Finally, evaluate the binding complex of the original target protein and ALL chains of this designed binder using the "
"'evaluate_binder_complex_alphafold2_multimer' tool. "
"You can optionally specify 'relax_prediction' (default is True)."
)
else: # Workflow complete or error
user_prompt_str = "The protein design workflow is complete. No further actions required by you for this item. If successful, the key metric was the pLDDT of the complex."
# ***** ADD RETRY PREFIX IF APPLICABLE *****
if state.get("retry_count_this_internal_step", 0) > 0 and internal_step < 4: # For all steps
# Prepend a retry notice if a previous attempt at this step failed
if state.get("retry_count_this_internal_step", 0) > 0 and internal_step < 4:
retry_prefix = "Your previous attempt at this step was not successful. "
if state.get("previous_tool_error_message"):
retry_prefix += f"Details: {state['previous_tool_error_message']}. "
retry_prefix += "Please review the requirements and try again to correctly use the expected tool.\n\n"
retry_prefix += "Please review the requirements and PDB details carefully and try again to correctly use the expected tool.\n\n"
user_prompt_str = retry_prefix + user_prompt_str
return user_prompt_str
@@ -204,7 +316,7 @@ class BinderBenchConfig(BaseEnvConfig):
nim_api_base_url: str = Field("https://health.api.nvidia.com/v1", description="NIM API base URL")
api_timeout: int = Field(1800, description="Timeout for NIM API calls") # Increased default
polling_interval: int = Field(30, description="Polling interval for NIM jobs") # Increased default
output_dir: str = Field("binder_outputs", description="Directory to save PDBs, etc.")
output_dir: str = Field(default=str(Path(__file__).parent / "outputs"), description="Directory to save PDBs, etc.")
debug_protein_design_calls: bool = Field(False, description="Enable debug mode for NIM protein API calls, returning mock data.")
max_retries_per_internal_step: int = Field(100, description="Max retries for a failed tool call within a workflow step (0 means no retries).")
# Dataset specific
@@ -349,8 +461,8 @@ class BinderBenchEnv(BaseEnv):
# Create a dummy PDB content if file not found to prevent downstream errors, but log severe warning
pdb_content = "HEADER DUMMY PDB FOR DEBUG - TARGET.PDB NOT FOUND\nATOM 1 N ALA A 1 0.000 0.000 0.000 1.00 0.00 N\nTER\nEND\n"
workflow_state["target_pdb_content"] = pdb_content
chain_lengths, pdb_preview = get_pdb_chain_lengths_and_preview(pdb_content) # Use your helper
workflow_state["target_chain_info"] = chain_lengths
chain_details, pdb_preview = get_pdb_chain_details(pdb_content) # Use the new function
workflow_state["target_chain_details"] = chain_details # Store detailed info
workflow_state["target_pdb_preview"] = pdb_preview
workflow_state["target_structure_predicted"] = True # Mark as "predicted" for workflow to proceed
return {"success": False, "error": f"Debug mode error: {fixed_pdb_path} not found.", "target_pdb_preview": pdb_preview}
@@ -360,15 +472,15 @@ class BinderBenchEnv(BaseEnv):
pdb_content = f.read()
workflow_state["target_pdb_content"] = pdb_content
chain_lengths, pdb_preview = get_pdb_chain_lengths_and_preview(pdb_content) # Use your helper
workflow_state["target_chain_info"] = chain_lengths
chain_details, pdb_preview = get_pdb_chain_details(pdb_content) # Use the new function
workflow_state["target_chain_details"] = chain_details # Store detailed info
workflow_state["target_pdb_preview"] = pdb_preview
workflow_state["target_structure_predicted"] = True
# Optionally, save a copy of this mock PDB to the usual output location for consistency
debug_output_pdb_path = self.output_dir / f"target_{item_id}_s{workflow_state['current_internal_step']}_af2_DEBUG.pdb"
with open(debug_output_pdb_path, "w") as f: f.write(pdb_content)
logger.info(f"DEBUG MODE: Used fixed PDB from {fixed_pdb_path}. Copied to {debug_output_pdb_path}. Chain info: {chain_lengths}")
logger.info(f"DEBUG MODE: Used fixed PDB from {fixed_pdb_path}. Copied to {debug_output_pdb_path}. Chain details: {chain_details}")
return {"success": True, "message": "DEBUG MODE: Used fixed PDB for AlphaFold2.", "target_pdb_preview": pdb_preview}
except Exception as e:
@@ -393,13 +505,13 @@ class BinderBenchEnv(BaseEnv):
if api_result and isinstance(api_result, list) and api_result[0]:
pdb_content = api_result[0]
workflow_state["target_pdb_content"] = pdb_content
chain_lengths, pdb_preview = get_pdb_chain_lengths_and_preview(pdb_content)
workflow_state["target_chain_info"] = chain_lengths
chain_details, pdb_preview = get_pdb_chain_details(pdb_content) # Use the new function
workflow_state["target_chain_details"] = chain_details # Store detailed info
workflow_state["target_pdb_preview"] = pdb_preview
workflow_state["target_structure_predicted"] = True
pdb_path = self.output_dir / f"target_{item_id}_s{workflow_state['current_internal_step']}_af2.pdb"
with open(pdb_path, "w") as f: f.write(pdb_content)
logger.info(f"Workflow {item_id}: AlphaFold2 PDB saved to {pdb_path}. Chain info: {chain_lengths}")
logger.info(f"Workflow {item_id}: AlphaFold2 PDB saved to {pdb_path}. Chain details: {chain_details}")
return {"success": True, "message": "AlphaFold2 prediction complete.", "target_pdb_preview": pdb_preview}
else:
logger.error(f"Workflow {item_id}: AlphaFold2 call failed or returned unexpected data: {api_result}")
@@ -434,47 +546,125 @@ class BinderBenchEnv(BaseEnv):
async def _run_nim_proteinmpnn(self, args: Dict, workflow_state: Dict) -> Dict:
binder_pdb = workflow_state.get("binder_backbone_pdb_content")
if not binder_pdb: return {"success": False, "error": "Binder backbone PDB not found for ProteinMPNN."}
if not binder_pdb:
return {"success": False, "error": "Binder backbone PDB not found for ProteinMPNN."}
sampling_temp = args.get("sampling_temp", [0.1]) # Default if not provided
sampling_temp_list = args.get("sampling_temp", [0.1])
item_id = workflow_state["item_id"]
api_result = await call_proteinmpnn(
input_pdb=binder_pdb, api_key=self.config.nim_api_key,
sampling_temp=sampling_temp,
sampling_temp=sampling_temp_list,
timeout=self.config.api_timeout, polling_interval=self.config.polling_interval
# Add other PMPNN specific params
)
if api_result and "mfasta" in api_result:
fasta_content = api_result["mfasta"]
designed_sequence = ""
for line in fasta_content.splitlines():
if not line.startswith(">") and line.strip():
designed_sequence = line.strip()
break
if not designed_sequence:
return {"success": False, "error": "Could not parse sequence from ProteinMPNN FASTA."}
workflow_state["designed_binder_sequence"] = designed_sequence
workflow_state["binder_sequence_designed"] = True
fasta_path = self.output_dir / f"binder_sequence_{item_id}_s{workflow_state['current_internal_step']}_pmpnn.fasta"
with open(fasta_path, "w") as f: f.write(fasta_content)
logger.info(f"Workflow {item_id}: ProteinMPNN FASTA saved to {fasta_path}")
# NO LONGER INCREMENT current_internal_step HERE - collect_trajectories will handle this
return {"success": True, "message": "ProteinMPNN complete.", "designed_binder_sequence": designed_sequence}
else:
if not (api_result and "mfasta" in api_result):
logger.error(f"Workflow {item_id}: ProteinMPNN call failed or returned unexpected data: {api_result}")
return {"success": False, "error": "ProteinMPNN failed."}
return {"success": False, "error": "ProteinMPNN call failed or no mfasta in result."}
fasta_content = api_result["mfasta"]
logger.info(f"CRITICAL_DEBUG: ProteinMPNN raw mfasta output for item {item_id}:\n{fasta_content}")
# --- FASTA Parsing Logic to find best sequence by global_score ---
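# Headers are expected to carry a 'global_score=<float>' field (ProteinMPNN FASTA headers typically look
# roughly like '>sample=1, score=..., global_score=..., seq_recovery=...'); entries missing it sort last.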
entries: List[Tuple[float, str, str]] = [] # (global_score, header, sequence_line)
current_header = None
current_sequence_parts: List[str] = []
for line_content in fasta_content.splitlines():
line = line_content.strip()
if not line: continue
if line.startswith(">"):
if current_header and current_sequence_parts: # Process previous entry
full_sequence_line = "".join(current_sequence_parts)
score_match = re.search(r"global_score=([-\d.]+)", current_header)
global_score = float(score_match.group(1)) if score_match else -float('inf')
entries.append((global_score, current_header, full_sequence_line))
current_header = line
current_sequence_parts = []
else:
current_sequence_parts.append(line)
if current_header and current_sequence_parts: # Process the last entry
full_sequence_line = "".join(current_sequence_parts)
score_match = re.search(r"global_score=([-\d.]+)", current_header)
global_score = float(score_match.group(1)) if score_match else -float('inf')
entries.append((global_score, current_header, full_sequence_line))
if not entries:
logger.error(f"Workflow {item_id}: No sequences found in ProteinMPNN mfasta output.")
return {"success": False, "error": "No sequences parsed from ProteinMPNN mfasta."}
# Sort by global_score (descending) and select the best
entries.sort(key=lambda x: x[0], reverse=True)
best_global_score, best_header, best_full_sequence_line = entries[0]
logger.info(f"Workflow {item_id}: Best PMPNN sequence chosen (global_score={best_global_score:.4f}) from header: '{best_header}'")
logger.info(f"Workflow {item_id}: Corresponding sequence line: '{best_full_sequence_line}'")
# Split the selected sequence line by '/' to handle potential chainbreaks
parsed_binder_chains: List[str] = [
seq_part.strip() for seq_part in best_full_sequence_line.split('/') if seq_part.strip()
]
if not parsed_binder_chains:
error_msg = f"Splitting best PMPNN sequence ('{best_full_sequence_line}') by '/' yielded no valid chains."
logger.error(f"Workflow {item_id}: {error_msg}")
return {"success": False, "error": error_msg}
# Validate each parsed chain (ensure they are valid protein sequences)
for seq_idx, seq_part in enumerate(parsed_binder_chains):
if not (seq_part and seq_part.isalpha() and seq_part.isupper()):
error_msg = f"Parsed binder chain {seq_idx+1} ('{seq_part[:30]}...') contains invalid characters or is empty."
logger.error(f"Workflow {item_id}: {error_msg}")
return {"success": False, "error": error_msg}
workflow_state["designed_binder_sequence"] = parsed_binder_chains # Store as List[str]
workflow_state["binder_sequence_designed"] = True
fasta_path = self.output_dir / f"binder_sequence_{item_id}_s{workflow_state['current_internal_step']}_pmpnn.fasta"
with open(fasta_path, "w") as f: f.write(fasta_content) # Save original full FASTA
logger.info(f"Workflow {item_id}: ProteinMPNN FASTA saved to {fasta_path}. Selected binder chains: {parsed_binder_chains}")
preview = parsed_binder_chains[0][:60] + "..." if len(parsed_binder_chains[0]) > 60 else parsed_binder_chains[0]
if len(parsed_binder_chains) > 1:
preview += f" (+ {len(parsed_binder_chains)-1} more chain(s))"
return {
"success": True,
"message": f"ProteinMPNN complete. Selected best (global_score={best_global_score:.4f}).",
"designed_binder_sequence_list": parsed_binder_chains,
"designed_binder_sequence_preview": preview
}
async def _run_nim_af2_multimer(self, args: Dict, workflow_state: Dict) -> Dict:
target_seq = workflow_state.get("target_sequence")
binder_seq = workflow_state.get("designed_binder_sequence")
if not target_seq or not binder_seq:
return {"success": False, "error": "Missing target or binder sequence for AlphaFold2-Multimer."}
binder_seq_data = workflow_state.get("designed_binder_sequence")
if not target_seq:
return {"success": False, "error": "Missing target sequence for AlphaFold2-Multimer."}
if not binder_seq_data:
return {"success": False, "error": "Missing binder sequence for AlphaFold2-Multimer."}
# Handle binder_seq_data which could now be either a List[str] or a single string (for backward compatibility)
binder_sequences = []
if isinstance(binder_seq_data, list):
binder_sequences = binder_seq_data
elif isinstance(binder_seq_data, str):
binder_sequences = [binder_seq_data] # Wrap in list
else:
return {"success": False, "error": f"Unexpected type for binder sequence: {type(binder_seq_data)}"}
if not binder_sequences:
return {"success": False, "error": "Empty binder sequence list for AlphaFold2-Multimer."}
relax = args.get("relax_prediction", True)
item_id = workflow_state["item_id"]
logger.info(f"Workflow {item_id}: Running AlphaFold2-Multimer with target (len {len(target_seq)}) and binder (len {len(binder_seq)}).")
# Log all sequences for debugging
total_binder_length = sum(len(seq) for seq in binder_sequences)
logger.info(f"Workflow {item_id}: Running AlphaFold2-Multimer with target (len {len(target_seq)}) and {len(binder_sequences)} binder chain(s) (total len {total_binder_length}).")
# Check if in debug mode
if self.config.debug_protein_design_calls:
@@ -515,12 +705,18 @@ class BinderBenchEnv(BaseEnv):
}
# Non-debug mode: proceed with actual API call
# Create a list with target sequence as first element, followed by all binder sequences
all_sequences = [target_seq] + binder_sequences
logger.info(f"Workflow {item_id}: Calling AlphaFold2-Multimer with {len(all_sequences)} sequences: "
f"1 target ({len(target_seq)} aa) + {len(binder_sequences)} binder chains.")
api_result = await call_alphafold2_multimer(
sequences=[target_seq, binder_seq],
sequences=all_sequences, # Pass all sequences as a flat list: [target_seq, binder_seq1, binder_seq2, ...]
api_key=self.config.nim_api_key,
relax_prediction=relax,
timeout=self.config.api_timeout, # Pass configured timeout
polling_interval=self.config.polling_interval # Pass configured polling interval
timeout=self.config.api_timeout,
polling_interval=self.config.polling_interval
)
if api_result and api_result.get("structures") and len(api_result["structures"]) > 0:
@@ -624,6 +820,7 @@ class BinderBenchEnv(BaseEnv):
"target_sequence": target_sequence,
"ground_truth_binder_sequence": ground_truth_binder, # Store for final evaluation
"target_pdb_content": None,
"target_chain_details": None, # Store detailed chain information
"binder_backbone_pdb_content": None,
"designed_binder_sequence": None,
"complex_pdb_content_path": None, # Path to AF2-Multimer output
@@ -758,7 +955,7 @@ class BinderBenchEnv(BaseEnv):
workflow_state["retry_count_this_internal_step"] = 0 # Reset for new step
workflow_state["previous_tool_error_message"] = None
else: # Tool call failed or was incorrect
if workflow_state["current_internal_step"] < 3: # Only retry for steps 0-2
if workflow_state["current_internal_step"] <= 3: # Retry for steps 0, 1, 2, AND 3
workflow_state["retry_count_this_internal_step"] += 1
if workflow_state["retry_count_this_internal_step"] > self.config.max_retries_per_internal_step:
logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: Max retries ({self.config.max_retries_per_internal_step}) reached. Terminating workflow for this item.")
@@ -767,8 +964,8 @@ class BinderBenchEnv(BaseEnv):
else:
logger.info(f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: Failed, attempt {workflow_state['retry_count_this_internal_step']}. Retrying same step.")
# Loop continues, construct_user_prompt will use retry info
else: # Failure at step 3 (AF2M) is terminal for the workflow
logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: Failure at critical AF2-Multimer step. Terminating workflow.")
else: # Should never reach here with MAX_INTERNAL_STEPS = 4, but keeping for safety
logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: Failure at non-retryable step. Terminating workflow.")
workflow_state["workflow_complete_flag"] = True
break # Exit the internal while loop
@@ -778,24 +975,101 @@ class BinderBenchEnv(BaseEnv):
# No break here, loop condition will handle it
# After the internal while loop (for process mode)
if not all_turns_data_for_jsonl: return None, []
last_turn_data = all_turns_data_for_jsonl[-1]
aggregated_messages = [turn_data["messages_this_turn"] for turn_data in all_turns_data_for_jsonl]
aggregated_overrides = [turn_data["overrides_this_turn"] for turn_data in all_turns_data_for_jsonl]
final_reward_for_group = workflow_state.get("cumulative_reward", 0.0)
if workflow_state.get("complex_evaluated") and workflow_state.get("last_tool_success"):
final_reward_for_group = last_turn_data["overrides_this_turn"].get("overall_reward", 0.0)
if not all_turns_data_for_jsonl:
logger.warning(f"Workflow {item_id} in process mode: No turn data collected.")
return None, []
# Build per-turn messages, scores, and overrides in the shape expected by jsonl2html
html_compatible_messages: List[str] = []
html_compatible_scores: List[float] = []
# `overrides_for_jsonl` will store the detailed scoring dict for each turn,
# matching the structure of `html_compatible_messages` and `html_compatible_scores`.
overrides_for_jsonl: List[Dict[str, Any]] = []
for turn_idx, turn_data in enumerate(all_turns_data_for_jsonl):
# Format messages for this turn into a single readable string
turn_str_parts = [f"--- Workflow {item_id} - Turn {turn_idx + 1} ---"]
if turn_data.get("messages_this_turn"):
for msg_obj in turn_data["messages_this_turn"]:
content_str = str(msg_obj.get("content", "[No Content]"))
if msg_obj.get("tool_calls"):
try:
tool_calls_str = json.dumps(msg_obj.get("tool_calls"), indent=2)
content_str += f"\nTool Calls:\n{tool_calls_str}"
except TypeError: # Handle non-serializable content if any
content_str += f"\nTool Calls: [Error serializing tool_calls]"
turn_str_parts.append(f"**{msg_obj.get('role', 'unknown').upper()}**: {content_str}")
else:
turn_str_parts.append("No messages recorded for this turn.")
html_compatible_messages.append("\n\n".join(turn_str_parts))
# Get the score for this specific turn
turn_score = turn_data.get("overrides_this_turn", {}).get("overall_reward", 0.0)
html_compatible_scores.append(turn_score)
# Add the detailed scoring dictionary for this turn
overrides_for_jsonl.append(turn_data.get("overrides_this_turn", {}))
final_workflow_reward = workflow_state.get("cumulative_reward", 0.0)
# If the complex was evaluated successfully, the last turn's reward is the final one.
if workflow_state.get("complex_evaluated") and workflow_state.get("last_tool_success"):
final_workflow_reward = all_turns_data_for_jsonl[-1].get("overrides_this_turn", {}).get("overall_reward", 0.0)
# Build the ScoredDataGroup consumed by BaseEnv: each turn needs its own tokens and masks so the
# group length matches the per-turn messages/scores used for JSONL visualization. Use the actual
# per-turn tokens where available rather than placeholders.
all_tokens_per_turn = [turn_data["tokens_this_turn"] for turn_data in all_turns_data_for_jsonl if turn_data.get("tokens_this_turn")]
all_masks_per_turn = [turn_data["masks_this_turn"] for turn_data in all_turns_data_for_jsonl if turn_data.get("masks_this_turn")]
# all_tokens_per_turn and all_masks_per_turn must match the length of html_compatible_messages;
# turns that produced no tokens (e.g. after an error) would break this, so the mismatch is handled below.
if len(all_tokens_per_turn) != len(html_compatible_messages):
logger.error(f"CRITICAL: Mismatch between tokenized turns ({len(all_tokens_per_turn)}) and HTML messages ({len(html_compatible_messages)}). JSONL will be problematic.")
# Fallback: repeat last turn's tokens if necessary, though this isn't ideal.
if all_turns_data_for_jsonl and all_tokens_per_turn:
last_tokens = all_tokens_per_turn[-1]
last_masks = all_masks_per_turn[-1]
all_tokens_per_turn = [last_tokens] * len(html_compatible_messages)
all_masks_per_turn = [last_masks] * len(html_compatible_messages)
else: # No token data at all
all_tokens_per_turn = [[] for _ in html_compatible_messages]
all_masks_per_turn = [[] for _ in html_compatible_messages]
# This is the ScoredDataGroup that will be written to JSONL by BaseEnv
process_mode_scored_data = ScoredDataGroup(
tokens=[last_turn_data["tokens_this_turn"]],
masks=[last_turn_data["masks_this_turn"]],
scores=[final_reward_for_group],
messages=aggregated_messages if self.config.include_messages else None,
overrides=aggregated_overrides,
group_overrides={"group_size": 1} # ***** THIS IS THE DEFINITIVE FIX FOR THIS PART *****
tokens=all_tokens_per_turn, # List of token lists, one for each turn
masks=all_masks_per_turn, # List of mask lists, one for each turn
# These are critical for jsonl2html
messages=html_compatible_messages, # List of strings, one per turn
scores=html_compatible_scores, # List of floats, one per turn
# Store detailed overrides per turn, matching the length of messages/scores
overrides=overrides_for_jsonl,
group_overrides={
"group_size": len(html_compatible_messages), # Effective group size is number of turns
"item_id": item_id,
"is_process_mode_full_workflow": True,
"final_score_for_workflow": final_workflow_reward, # Store the overall workflow score here
"target_sequence": workflow_state.get("target_sequence", "N/A"),
"designed_binder_sequence": workflow_state.get("designed_binder_sequence", "N/A"),
"final_plddt": workflow_state.get("af2_multimer_plddt", 0.0)
}
)
# Log completed workflow for WandB before adding to metrics
await self.add_rollouts_for_wandb(workflow_state)
# Log the detailed workflow state to WandB (pass workflow_state directly for full detail)
await self.add_rollouts_for_wandb(data_for_log=workflow_state.copy())
self.completed_episode_metrics.append(workflow_state.copy())
if item_id in self.episodes_state: del self.episodes_state[item_id]
return process_mode_scored_data, []
@@ -868,25 +1142,25 @@ class BinderBenchEnv(BaseEnv):
workflow_state["retry_count_this_internal_step"] = 0
workflow_state["previous_tool_error_message"] = None
else: # Tool failed or was incorrect
if workflow_state["current_internal_step"] < 3: # Only retry for steps 0-2
if workflow_state["current_internal_step"] <= 3: # Retry for steps 0, 1, 2, AND 3
workflow_state["retry_count_this_internal_step"] += 1
if workflow_state["retry_count_this_internal_step"] > self.config.max_retries_per_internal_step:
logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']} (Serve Mode): Max retries reached. Terminating.")
workflow_state["workflow_complete_flag"] = True
# else: it will be added to backlog below to retry
else: # Failure at a non-retryable step (should never be reached with MAX_INTERNAL_STEPS = 4)
    logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']} (Serve Mode): Failure at non-retryable step. Terminating.")
workflow_state["workflow_complete_flag"] = True
if workflow_state["current_internal_step"] < 4 and not workflow_state.get("workflow_complete_flag"):
# Add to backlog if:
# 1. Last tool was successful (to move to next step)
# OR
# 2. Last tool failed, current step is < 3, and we haven't hit max retries (to retry current step)
# 2. Last tool failed, current step is <= 3, and we haven't hit max retries (to retry current step)
should_add_to_backlog = False
if workflow_state["last_tool_success"]:
should_add_to_backlog = True
elif workflow_state["current_internal_step"] < 3 and \
elif workflow_state["current_internal_step"] <= 3 and \
workflow_state["retry_count_this_internal_step"] <= self.config.max_retries_per_internal_step:
should_add_to_backlog = True
@ -897,10 +1171,16 @@ class BinderBenchEnv(BaseEnv):
logger.info(f"Workflow for {item_id} (Serve Mode) not added to backlog and marked complete. Internal step: {workflow_state['current_internal_step']}")
if workflow_state.get("workflow_complete_flag"): # If flag was set either by reaching step 4 or by retry logic
# For completed workflows in serve mode, use direct logging with workflow_state
# before it gets deleted
if item_id in self.episodes_state:
    # Use direct workflow_state logging for maximum detail
    await self.add_rollouts_for_wandb(data_for_log=self.episodes_state[item_id].copy())
    self.completed_episode_metrics.append(self.episodes_state[item_id].copy())
    del self.episodes_state[item_id]
# Note: We don't need to call add_rollouts_for_wandb with scored_data_serve here
# BaseEnv.handle_send_to_api will call it automatically with the scored_data_serve
# that we return
return scored_data_serve, backlog_items_serve
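The backlog decision above boils down to a small predicate. The sketch below is illustrative only (the name _should_requeue is not part of the codebase) and assumes the retry budget comes from config.max_retries_per_internal_step.
def _should_requeue(last_tool_success: bool, current_step: int, retries: int, max_retries: int) -> bool:
    # Re-queue to advance to the next internal step after a successful tool call,
    # or to retry the current step (steps 0-3) while the retry budget lasts.
    if last_tool_success:
        return True
    return current_step <= 3 and retries <= max_retries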
@ -968,7 +1248,7 @@ class BinderBenchEnv(BaseEnv):
logger.info(f"Workflow {workflow_state['item_id']}, Step {internal_step} (AF2-Multimer): pLDDT={plddt:.2f}. Reward: {detailed_scores['overall_reward']:.2f}")
else:
detailed_scores["overall_reward"] = -0.5
detailed_scores["overall_reward"] = 0.0
logger.warning(f"Workflow {workflow_state['item_id']}, Step {internal_step} (AF2-Multimer): Evaluation failed or wrong tool. Reward: -0.5. Last tool success: {last_tool_success}, Called: {called_tool_name}")
else:
@ -1046,24 +1326,62 @@ class BinderBenchEnv(BaseEnv):
# Given it's populated by collect_trajectories, clearing seems appropriate for periodic eval.
self.completed_episode_metrics.clear()
async def add_rollouts_for_wandb(self,
                                 scored_data_group: ScoredDataGroup = None,  # From BaseEnv
                                 item_id: Item = None,  # From BaseEnv
                                 data_for_log: Dict = None):  # Our custom param for direct workflow_state logging
"""Adds a workflow summary to the wandb rollout buffer.
This method has two modes of operation:
1. Direct logging with workflow_state (preferred for detailed logging):
- Called from within collect_trajectories with data_for_log=workflow_state.copy()
- This provides maximum detail for logging
2. BaseEnv compatibility mode:
- Called from BaseEnv.handle_send_to_api with scored_data_group and item_id
- Used automatically by the framework
- May have less detail if workflow_state was already deleted
Args:
scored_data_group: The ScoredDataGroup containing token, mask, and score data (from BaseEnv)
item_id: The item identifier, which is the key to our episodes_state (from BaseEnv)
data_for_log: Direct workflow state to log (our custom parameter for direct logging)
"""
if not self.config.use_wandb or not hasattr(self, "rollouts_for_wandb"):
# Ensure rollouts_for_wandb exists; BaseEnv usually initializes it in its __init__, but init here if missing
if not hasattr(self, "rollouts_for_wandb"):
self.rollouts_for_wandb = []
# Determine the workflow state to use
workflow_state = None
# Case 1: Direct workflow state provided (most detailed)
if data_for_log is not None and isinstance(data_for_log, dict):
workflow_state = data_for_log
# Extract item_id from data_for_log if needed
if item_id is None and "item_id" in workflow_state:
item_id = workflow_state["item_id"]
# Case 2: Try to get workflow_state from episodes_state (if not already deleted)
elif item_id is not None and item_id in self.episodes_state:
workflow_state = self.episodes_state[item_id]
# Case 3: No usable state - early return with a debug log (not warning)
# This happens in BaseEnv.handle_send_to_api after workflow is already completed
if workflow_state is None:
# This is expected in BaseEnv's call after workflow_state is deleted, so use debug level
logger.debug(f"No workflow_state available for WandB logging (item_id={item_id})")
return
# Customize what you want to see in the WandB table for a completed workflow
# Handle cases where values might be None
item_id = workflow_state.get("item_id", "unknown-id")
target_seq = workflow_state.get("target_sequence", "N/A")
# Handle designed_binder which might be None
designed_binder = workflow_state.get("designed_binder_sequence", "N/A")
if designed_binder is None:
designed_binder = "N/A"
plddt = workflow_state.get("af2_multimer_plddt", 0.0)
iptm = workflow_state.get("af2_multimer_iptm", 0.0) # Even if 0, log it
cumulative_reward = workflow_state.get("cumulative_reward", 0.0)
@ -1084,15 +1402,20 @@ class BinderBenchEnv(BaseEnv):
# Safely truncate strings
target_preview = target_seq[:30] + "..." if isinstance(target_seq, str) and len(target_seq) > 30 else target_seq
if designed_binder == "N/A" or designed_binder is None:
binder_preview = "N/A"
else:
binder_preview = designed_binder[:30] + "..." if len(str(designed_binder)) > 30 else designed_binder
# Use item_id from workflow_state if still None
if item_id is None:
item_id = workflow_state.get("item_id", "unknown-id")
# Add to rollouts buffer
self.rollouts_for_wandb.append(
( # This tuple structure will be used by create_rollout_table
str(item_id), # Ensure item_id is a string
target_preview,
binder_preview,
f"{plddt:.2f}",

View file

@ -0,0 +1 @@
"""Utility functions for the protein design environment."""

View file

@ -0,0 +1,28 @@
"""API utility functions for the protein design environment."""
import os
import logging
import yaml
from pathlib import Path
from typing import Optional
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()
logger = logging.getLogger(__name__)
def load_api_key() -> Optional[str]:
"""
Load the NVIDIA NIM API key from environment variables.
Returns:
The API key from environment variables, or None if not found
"""
api_key = os.environ.get("NVIDIA_NIM_API_KEY")
if not api_key:
logger.error("NVIDIA_NIM_API_KEY not found in environment variables. "
"Please set it in your .env file.")
return None
return api_key
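A minimal usage sketch for load_api_key; the import path below is illustrative and depends on where this module sits in the package.
from utils.api_utils import load_api_key  # assumed import path

api_key = load_api_key()
if api_key is None:
    raise RuntimeError("NVIDIA_NIM_API_KEY is not set; add it to your .env file before running.")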