mirror of https://github.com/NousResearch/atropos.git
synced 2026-05-01 17:45:16 +00:00

commit 1ee67de035 (parent de9dfff221): "refactor, full run"
12 changed files with 1039 additions and 1127 deletions
@ -1,33 +1,19 @@
# NVIDIA NIM Environment Default Configuration for BinderBench

# Debug Mode - set to true to use mock data instead of actual API calls
debug_protein_design_calls: false

# Retry settings for failed steps
max_retries_per_internal_step: 100

# API Settings
# nim_api_key is loaded from .env file using NVIDIA_NIM_API_KEY
nim_api_base_url: "https://health.api.nvidia.com/v1"
api_timeout: 600
polling_interval: 10

# Protein Design Settings
output_dir: "environments/hack0/protein_design_env/outputs"

# WandB tracking settings
use_wandb: true
wandb_name: "binderbench"
wandb_project: "atropos"
include_messages: true

# Dataset configuration
dataset_name: "ronig/protein_binding_sequences"
target_col: "receptor"
binder_col: "peptide"

# Scoring weights for final complex quality
metric_weights:
  plddt: 0.3
  ptm: 0.3
  iptm: 0.4
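The weights above sum to 1.0 and favor the interface score (ipTM). A minimal sketch (illustrative, not from this commit) of how such weights could fold the three metrics into one scalar, assuming pLDDT is reported on a 0-100 scale and pTM/ipTM on 0-1:

metric_weights = {"plddt": 0.3, "ptm": 0.3, "iptm": 0.4}

def combined_score(plddt: float, ptm: float, iptm: float) -> float:
    # Normalize pLDDT (0-100) to the 0-1 range of pTM/ipTM before weighting.
    metrics = {"plddt": plddt / 100.0, "ptm": ptm, "iptm": iptm}
    return sum(metric_weights[name] * value for name, value in metrics.items())

print(combined_score(plddt=82.5, ptm=0.71, iptm=0.63))  # 0.3*0.825 + 0.3*0.71 + 0.4*0.63 = 0.7125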
@ -1,5 +1,3 @@
"""AlphaFold2 API integration for NVIDIA NIM."""

import os
import logging
import aiohttp

@ -10,7 +8,6 @@ from pathlib import Path
logger = logging.getLogger(__name__)

# Default URL
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"

@ -26,12 +23,12 @@ async def call_alphafold2(
    url: str = DEFAULT_URL,
    status_url: str = DEFAULT_STATUS_URL,
    polling_interval: int = 10,
    timeout: int = 600,
    max_retries: int = 3
) -> Optional[Dict[str, Any]]:
    """
    Call the NVIDIA NIM AlphaFold2 API.

    Args:
        sequence: Protein sequence in one-letter code
        api_key: NVIDIA NIM API key

@ -45,18 +42,16 @@ async def call_alphafold2(
        status_url: Status URL for checking job completion
        polling_interval: Seconds between status checks
        timeout: Request timeout in seconds

    Returns:
        API response or None on failure
    """
    headers = {
        "content-type": "application/json",
        "Authorization": f"Bearer {api_key}",
        "NVCF-POLL-SECONDS": "300",
    }

    data = {
        "sequence": sequence,
        "algorithm": algorithm,

@ -66,7 +61,7 @@ async def call_alphafold2(
        "relax_prediction": relax_prediction,
        "skip_template_search": skip_template_search
    }

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(

@ -75,11 +70,9 @@ async def call_alphafold2(
                headers=headers,
                timeout=timeout
            ) as response:
                if response.status == 200:
                    return await response.json()
                elif response.status == 202:
                    req_id = response.headers.get("nvcf-reqid")
                    if req_id:
                        logger.info(f"AlphaFold2 job submitted, request ID: {req_id}")

@ -103,7 +96,7 @@ async def call_alphafold2(
        logger.error(f"Error calling AlphaFold2 API: {e}")
        logger.error(traceback.format_exc())
        return None


async def _poll_job_status(
    req_id: str,
    headers: Dict[str, str],

@ -113,14 +106,14 @@ async def _poll_job_status(
) -> Optional[Dict[str, Any]]:
    """
    Poll the status endpoint until the job completes.

    Args:
        req_id: The request ID to check
        headers: Request headers
        status_url: Status URL for checking job completion
        polling_interval: Seconds between status checks
        timeout: Request timeout in seconds

    Returns:
        The final response or None on failure
    """

@ -133,11 +126,9 @@ async def _poll_job_status(
                timeout=timeout
            ) as response:
                if response.status == 200:
                    logger.info(f"AlphaFold2 job {req_id} completed")
                    return await response.json()
                elif response.status == 202:
                    logger.debug(f"AlphaFold2 job {req_id} still running, polling...")
                    await asyncio.sleep(polling_interval)
                else:

@ -147,4 +138,4 @@ async def _poll_job_status(
                    return None
    except Exception as e:
        logger.error(f"Error polling AlphaFold2 job status: {e}")
        return None
    return None
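A hypothetical driver for this client (not part of the commit), using the import path seen in tool_executor.py further down and assuming NVIDIA_NIM_API_KEY is exported; the sequence is a toy value:

import asyncio
import os

from environments.hack0.protein_design_env.models.alphafold2 import call_alphafold2

async def main() -> None:
    result = await call_alphafold2(
        sequence="MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",  # toy one-letter-code sequence
        api_key=os.environ["NVIDIA_NIM_API_KEY"],
        polling_interval=10,
        timeout=600,
    )
    print("prediction failed" if result is None else "received prediction payload")

asyncio.run(main())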
@ -1,5 +1,3 @@
"""AlphaFold2-Multimer API integration for NVIDIA NIM."""

import os
import logging
import aiohttp

@ -7,16 +5,12 @@ import json
import asyncio
from typing import Dict, List, Any, Optional, Tuple
from pathlib import Path

logger = logging.getLogger(__name__)

# Default URL
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2-multimer"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"

# Helper functions
def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
    """
    Splits a string containing concatenated PDB file contents.

@ -32,12 +26,12 @@ def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
        if line.startswith("ENDMDL") or line.startswith("END "):
            pdbs.append("".join(current_pdb_lines).strip())
            current_pdb_lines = []

    if current_pdb_lines:
        remaining_pdb = "".join(current_pdb_lines).strip()
        if remaining_pdb:
            pdbs.append(remaining_pdb)

    return [pdb for pdb in pdbs if pdb]
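A small illustration (not from the commit) of the splitter on two concatenated minimal models, assuming the elided top of the loop accumulates each line, with its newline, into current_pdb_lines:

concatenated = (
    "ATOM      1  N   MET A   1      11.104   6.134  -6.504  1.00 90.00           N\n"
    "ENDMDL\n"
    "ATOM      1  N   GLY B   1       8.130   2.120   0.000  1.00 75.00           N\n"
    "END   \n"
)
models = _split_pdb_content(concatenated)
print(len(models))  # 2: one string per model, split on END/ENDMDL records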
@ -73,139 +67,102 @@ def calculate_plddt_from_pdb_string(pdb_string: str) -> Tuple[float, List[float]
                plddt_scores_per_chain[chain_id].append(plddt_value)

        except ValueError:
            pass
        except IndexError:
            pass

    if ca_atom_count == 0:
        return 0.0, [], {}

    average_plddt = total_plddt / ca_atom_count
    return average_plddt, plddt_scores_per_ca, plddt_scores_per_chain


async def _process_pdb_and_scores_from_api(
    pdb_contents: List[str],
    job_id: str,
    api_response_json: Optional[Dict[str, Any]] = None,
    output_dir: Optional[Path] = None
) -> Optional[Dict[str, Any]]:
    """
    Processes a list of PDB strings received from the API JSON response.
    - The API responds with a direct list of PDB strings, not a nested JSON structure.
    - This function saves each PDB to disk and calculates pLDDT scores.
    - The api_response_json parameter is for potential future use if the API adds metadata.
    - The output_dir parameter allows for customizing where files are saved.
    """
    if output_dir is None:
        # Default behavior if no output dir is provided
        output_dir_name = f"alphafold2_multimer_output_{job_id}_results"
        output_dir = Path(f"./{output_dir_name}")

    # Make sure the directory exists
    output_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Saving AlphaFold2-Multimer results for job {job_id} to directory: {output_dir}")

    # Initialize the results dictionary that call_alphafold2_multimer will return
    results: Dict[str, Any] = {
        "output_directory": str(output_dir),
        "structures": [],
        "ptm_score": None,
        "iptm_score": None,
    }

    if isinstance(api_response_json, dict):
        logger.info(f"Attempting to extract additional scores from API JSON response for job {job_id}.")
        results["ptm_score"] = api_response_json.get("ptm")
        results["iptm_score"] = api_response_json.get("iptm")
        if results["ptm_score"] is not None or results["iptm_score"] is not None:
            logger.info(f"Extracted from API JSON: pTM={results['ptm_score']}, ipTM={results['iptm_score']}")
    else:
        logger.info(f"No API JSON dictionary provided (or it is not a dict) for job {job_id}; ptm/iptm scores will be None unless found elsewhere.")

    if not pdb_contents or not isinstance(pdb_contents, list) or not all(isinstance(s, str) for s in pdb_contents):
        logger.warning(f"No valid PDB content strings provided for job {job_id}.")
        return results  # Return with empty structures list

    logger.info(f"Processing {len(pdb_contents)} PDB structure(s) for job {job_id}.")

    for i, pdb_str in enumerate(pdb_contents):
        if not pdb_str.strip():
            logger.debug(f"Skipping empty PDB string at index {i} for job {job_id}.")
            continue

        structure_data: Dict[str, Any] = {
            "model_index": i,
            "pdb_content": pdb_str
        }

        avg_plddt, plddts_per_ca_residue, plddts_by_chain = calculate_plddt_from_pdb_string(pdb_str)

        structure_data["average_plddt"] = avg_plddt
        structure_data["plddt_scores_per_ca_residue"] = plddts_per_ca_residue
        structure_data["plddt_scores_per_chain"] = plddts_by_chain

        avg_plddt_per_chain = {}
        for chain_id, chain_plddts in plddts_by_chain.items():
            if chain_plddts:
                avg_plddt_per_chain[chain_id] = sum(chain_plddts) / len(chain_plddts)
            else:
                avg_plddt_per_chain[chain_id] = 0.0
        structure_data["average_plddt_per_chain"] = avg_plddt_per_chain

        pdb_file_name_stem = f"alphafold2_multimer_output_{job_id}"
        rank_suffix = f"_model_{i+1}"
        pdb_file_path = output_dir / f"{pdb_file_name_stem}{rank_suffix}.pdb"

        try:
            with open(pdb_file_path, "w", encoding='utf-8') as f_pdb:
                f_pdb.write(pdb_str)
            structure_data["saved_pdb_path"] = str(pdb_file_path)
            logger.info(f"Saved PDB model {i+1} for job {job_id} to {pdb_file_path} with overall avg_pLDDT: {avg_plddt:.2f}")
        except Exception as e_write:
            logger.error(f"Failed to write PDB file {pdb_file_path} for job {job_id}: {e_write}")
            structure_data["saved_pdb_path"] = None

        results["structures"].append(structure_data)

    if results["structures"]:
        logger.info(f"Successfully processed and calculated pLDDTs for {len(results['structures'])} structures for job {job_id}.")

    return results
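For context, pLDDT values live in the B-factor column of each ATOM record, which is what calculate_plddt_from_pdb_string reads. A self-contained sketch (illustrative, not the function above) of pulling CA-atom pLDDTs out of a PDB string:

from typing import List

def ca_plddts(pdb_string: str) -> List[float]:
    scores = []
    for line in pdb_string.splitlines():
        # PDB fixed-width columns: atom name in columns 13-16, B-factor (pLDDT here) in 61-66.
        if line.startswith("ATOM") and line[12:16].strip() == "CA":
            scores.append(float(line[60:66]))
    return scores

pdb = "ATOM      2  CA  MET A   1      10.500   5.200  -6.100  1.00 91.25           C"
print(ca_plddts(pdb))  # [91.25]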
async def call_alphafold2_multimer(

@ -220,16 +177,18 @@ async def call_alphafold2_multimer(
    url: str = DEFAULT_URL,
    status_url: str = DEFAULT_STATUS_URL,
    polling_interval: int = 30,
    timeout: int = 3600,
    output_dir: Optional[Path] = None
) -> Optional[Dict[str, Any]]:
    """
    Call the NVIDIA NIM AlphaFold2-Multimer API.
    The API returns JSON with a list of PDB strings, which we process to calculate pLDDT scores.
    Returns a dictionary with processed PDB strings and computed scores.
    """
    headers = {
        "content-type": "application/json",
        "Authorization": f"Bearer {api_key}",
        "NVCF-POLL-SECONDS": "300",
    }
    data: Dict[str, Any] = {
        "sequences": sequences,

@ -242,9 +201,9 @@ async def call_alphafold2_multimer(
    if selected_models is not None:
        data["selected_models"] = selected_models
        logger.info(f"Using selected_models: {selected_models}")

    try:
        initial_post_timeout = min(timeout, 600)
        async with aiohttp.ClientSession() as session:
            async with session.post(
                url,

@ -252,15 +211,26 @@ async def call_alphafold2_multimer(
                headers=headers,
                timeout=initial_post_timeout
            ) as response:
                if response.status == 200:
                    logger.info("AlphaFold2-Multimer job completed synchronously.")
                    logger.info(f"SYNC Final response headers: {response.headers}")
                    content_type = response.headers.get("Content-Type", "").lower()
                    logger.info(f"SYNC Final response content-type: {content_type}")

                    if "application/json" in content_type:
                        api_response_json_payload = await response.json()
                        if not isinstance(api_response_json_payload, list):
                            return {"success": False, "error": "Sync JSON response not a list of PDBs as expected."}
                        req_id_sync = response.headers.get("nvcf-reqid", "sync_job")  # Use the request ID, or a placeholder for synchronous jobs
                        return await _process_pdb_and_scores_from_api(
                            pdb_contents=api_response_json_payload,
                            job_id=req_id_sync,
                            api_response_json=None,
                            output_dir=output_dir
                        )
                    else:
                        return {"success": False, "error": f"Sync response unexpected content type: {content_type}"}
                elif response.status == 202:
                    req_id = response.headers.get("nvcf-reqid")
                    if req_id:
                        logger.info(f"AlphaFold2-Multimer job submitted, request ID: {req_id}")

@ -273,47 +243,44 @@ async def call_alphafold2_multimer(
                        )
                    else:
                        logger.error("No request ID in 202 response headers")
                        return {"success": False, "error": "No request ID in 202 response headers"}
                else:
                    logger.error(f"Error calling AlphaFold2-Multimer API (POST): {response.status}")
                    text = await response.text()
                    logger.error(f"Response: {text}")
                    return {"success": False, "error": f"Error calling API: {response.status}", "detail": text}
    except asyncio.TimeoutError:
        logger.error("Timeout during AlphaFold2-Multimer API call (initial POST).")
        return {"success": False, "error": "Timeout during initial API request"}
    except Exception as e:
        logger.error(f"Exception during AlphaFold2-Multimer API call (initial POST): {e}", exc_info=True)
        return {"success": False, "error": f"Exception during API call: {str(e)}"}


async def _poll_job_status(
    req_id: str,
    headers: Dict[str, str],
    status_url: str,
    polling_interval: int = 30,
    overall_timeout: int = 3600,
    output_dir: Optional[Path] = None  # forwarded to _process_pdb_and_scores_from_api below
) -> Optional[Dict[str, Any]]:
    start_time = asyncio.get_event_loop().time()
    per_status_request_timeout = 600
    logger.info(f"Polling job {req_id}. Individual status check timeout: {per_status_request_timeout}s, Polling interval: {polling_interval}s, Overall timeout: {overall_timeout}s")

    while True:
        current_loop_time = asyncio.get_event_loop().time()
        elapsed_time = current_loop_time - start_time

        if elapsed_time >= overall_timeout:
            logger.error(f"Overall polling timeout of {overall_timeout}s exceeded for job {req_id}.")
            return {"success": False, "error": "Overall polling timeout exceeded."}

        remaining_time_for_overall_timeout = overall_timeout - elapsed_time
        current_status_check_timeout = min(per_status_request_timeout, remaining_time_for_overall_timeout)

        if current_status_check_timeout <= 0:
            logger.error(f"Not enough time left for another status check for job {req_id} within overall_timeout.")
            return {"success": False, "error": "Not enough time for status check within overall timeout."}

        try:
            async with aiohttp.ClientSession() as session:

@ -321,46 +288,62 @@ async def _poll_job_status(
                async with session.get(
                    f"{status_url}/{req_id}",
                    headers=headers,
                    timeout=current_status_check_timeout
                ) as response:
                    if response.status == 200:
                        logger.info(f"AlphaFold2-Multimer job {req_id} completed (status 200).")
                        logger.info(f"FINAL 200 OK Response Headers for job {req_id}: {response.headers}")
                        logger.info(f"FINAL 200 OK Content-Type for job {req_id}: {response.content_type}")

                        if response.content_type == 'application/json':
                            try:
                                api_response_json_payload = await response.json()
                                logger.debug(f"API JSON Response for job {req_id}: {str(api_response_json_payload)[:500]}...")

                                if not isinstance(api_response_json_payload, list):
                                    logger.error(f"Job {req_id}: Expected API response to be a list of PDB strings, got {type(api_response_json_payload)}.")
                                    return {"success": False, "error": "API response was not a list of PDB strings."}

                                return await _process_pdb_and_scores_from_api(
                                    pdb_contents=api_response_json_payload,
                                    job_id=req_id,
                                    api_response_json=None,
                                    output_dir=output_dir
                                )
                            except json.JSONDecodeError:
                                logger.error(f"Job {req_id}: Failed to decode JSON response from API.", exc_info=True)
                                raw_text = await response.text()
                                logger.debug(f"Raw text response: {raw_text[:500]}")
                                return {"success": False, "error": "Failed to decode JSON response."}
                        else:
                            logger.error(f"Job {req_id}: Unexpected content type {response.content_type}. Expected application/json.")
                            raw_text = await response.text()
                            logger.debug(f"Raw text response: {raw_text[:500]}")
                            return {"success": False, "error": f"Unexpected content type: {response.content_type}"}

                    elif response.status == 202:
                        try:
                            job_status_json = await response.json()
                            percent_complete = job_status_json.get('percentComplete', 'N/A')
                            status_message = job_status_json.get('status', 'running')
                            logger.debug(
                                f"Job {req_id} status: {status_message} ({percent_complete}% complete). Polling again in {polling_interval}s."
                            )
                        except (aiohttp.ContentTypeError, json.JSONDecodeError):
                            logger.debug(
                                f"Job {req_id} still running (202 status, non-JSON/malformed JSON body). Polling again in {polling_interval}s."
                            )
                        await asyncio.sleep(polling_interval)
                    else:
                        text = await response.text()
                        logger.error(f"Error checking AlphaFold2-Multimer job status {req_id}: HTTP {response.status} - {text}")
                        return {"success": False, "error": f"Status check failed with HTTP {response.status}", "detail": text}
        except asyncio.TimeoutError:
            logger.warning(f"Client-side timeout ({current_status_check_timeout}s) during status check for job {req_id}. Retrying poll after {polling_interval}s sleep.")
            await asyncio.sleep(polling_interval)
        except aiohttp.ClientError as e:
            logger.error(f"Client error polling job status for {req_id}: {e}. Retrying poll after {polling_interval}s.", exc_info=True)
            await asyncio.sleep(polling_interval)
        except Exception as e:
            logger.error(f"Unexpected error polling job status {req_id}: {e}", exc_info=True)
            return {"success": False, "error": f"Unexpected polling error: {str(e)}"}
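Downstream code can rank the returned structures by their computed scores. A hedged sketch (not from the commit) of picking the best model by average pLDDT from the results dictionary built above:

from typing import Optional

# Assumes `results` is the dict produced by _process_pdb_and_scores_from_api,
# i.e. its "structures" entries carry an "average_plddt" key.
def best_model(results: dict) -> Optional[dict]:
    structures = results.get("structures", [])
    if not structures:
        return None
    return max(structures, key=lambda s: s["average_plddt"])

# top = best_model(results)
# if top:
#     print(top["saved_pdb_path"], top["average_plddt"], results["iptm_score"])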
@ -1,5 +1,3 @@
"""ProteinMPNN API integration for NVIDIA NIM."""

import os
import logging
import aiohttp

@ -10,7 +8,6 @@ from pathlib import Path
logger = logging.getLogger(__name__)

# Default URL
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/ipd/proteinmpnn/predict"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"

@ -27,7 +24,7 @@ async def call_proteinmpnn(
) -> Optional[Dict[str, Any]]:
    """
    Call the NVIDIA NIM ProteinMPNN API.

    Args:
        input_pdb: PDB structure as a string
        api_key: NVIDIA NIM API key

@ -38,25 +35,23 @@ async def call_proteinmpnn(
        status_url: Status URL for checking job completion
        polling_interval: Seconds between status checks
        timeout: Request timeout in seconds

    Returns:
        API response or None on failure
    """
    headers = {
        "content-type": "application/json",
        "Authorization": f"Bearer {api_key}",
        "NVCF-POLL-SECONDS": "300",
    }

    data = {
        "input_pdb": input_pdb,
        "ca_only": ca_only,
        "use_soluble_model": use_soluble_model,
        "sampling_temp": sampling_temp
    }

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(

@ -65,11 +60,9 @@ async def call_proteinmpnn(
                headers=headers,
                timeout=timeout
            ) as response:
                if response.status == 200:
                    return await response.json()
                elif response.status == 202:
                    req_id = response.headers.get("nvcf-reqid")
                    if req_id:
                        logger.info(f"ProteinMPNN job submitted, request ID: {req_id}")

@ -91,7 +84,7 @@ async def call_proteinmpnn(
    except Exception as e:
        logger.error(f"Error calling ProteinMPNN API: {e}")
        return None


async def _poll_job_status(
    req_id: str,
    headers: Dict[str, str],

@ -101,14 +94,14 @@ async def _poll_job_status(
) -> Optional[Dict[str, Any]]:
    """
    Poll the status endpoint until the job completes.

    Args:
        req_id: The request ID to check
        headers: Request headers
        status_url: Status URL for checking job completion
        polling_interval: Seconds between status checks
        timeout: Request timeout in seconds

    Returns:
        The final response or None on failure
    """

@ -121,11 +114,9 @@ async def _poll_job_status(
                timeout=timeout
            ) as response:
                if response.status == 200:
                    logger.info(f"ProteinMPNN job {req_id} completed")
                    return await response.json()
                elif response.status == 202:
                    logger.debug(f"ProteinMPNN job {req_id} still running, polling...")
                    await asyncio.sleep(polling_interval)
                else:

@ -135,4 +126,4 @@ async def _poll_job_status(
                    return None
    except Exception as e:
        logger.error(f"Error polling ProteinMPNN job status: {e}")
        return None
    return None
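A hypothetical call (not from the commit) mirroring how tool_executor.py invokes this client further down; the backbone PDB and API key are placeholders:

import os

from environments.hack0.protein_design_env.models.proteinmpnn import call_proteinmpnn

async def design(backbone_pdb: str) -> None:
    result = await call_proteinmpnn(
        input_pdb=backbone_pdb,
        api_key=os.environ["NVIDIA_NIM_API_KEY"],
        sampling_temp=[0.1],  # lower temperature tends toward more conservative sequences
        timeout=600,
        polling_interval=10,
    )
    if result and "mfasta" in result:
        print(result["mfasta"].splitlines()[0])  # first FASTA header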
@ -1,5 +1,3 @@
"""RFDiffusion API integration for NVIDIA NIM."""

import os
import logging
import aiohttp

@ -10,7 +8,6 @@ from pathlib import Path
logger = logging.getLogger(__name__)

# Default URL
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/ipd/rfdiffusion/generate"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"

@ -27,7 +24,7 @@ async def call_rfdiffusion(
) -> Optional[Dict[str, Any]]:
    """
    Call the NVIDIA NIM RFDiffusion API.

    Args:
        input_pdb: PDB structure as a string
        api_key: NVIDIA NIM API key

@ -38,29 +35,26 @@ async def call_rfdiffusion(
        status_url: Status URL for checking job completion
        polling_interval: Seconds between status checks
        timeout: Request timeout in seconds

    Returns:
        API response or None on failure
    """
    headers = {
        "content-type": "application/json",
        "Authorization": f"Bearer {api_key}",
        "NVCF-POLL-SECONDS": "300",
    }

    data = {
        "input_pdb": input_pdb,
        "diffusion_steps": diffusion_steps
    }

    if contigs:
        data["contigs"] = contigs
    if hotspot_res:
        data["hotspot_res"] = hotspot_res

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(

@ -69,11 +63,9 @@ async def call_rfdiffusion(
                headers=headers,
                timeout=timeout
            ) as response:
                if response.status == 200:
                    return await response.json()
                elif response.status == 202:
                    req_id = response.headers.get("nvcf-reqid")
                    if req_id:
                        logger.info(f"RFDiffusion job submitted, request ID: {req_id}")

@ -95,7 +87,7 @@ async def call_rfdiffusion(
    except Exception as e:
        logger.error(f"Error calling RFDiffusion API: {e}")
        return None


async def _poll_job_status(
    req_id: str,
    headers: Dict[str, str],

@ -105,14 +97,14 @@ async def _poll_job_status(
) -> Optional[Dict[str, Any]]:
    """
    Poll the status endpoint until the job completes.

    Args:
        req_id: The request ID to check
        headers: Request headers
        status_url: Status URL for checking job completion
        polling_interval: Seconds between status checks
        timeout: Request timeout in seconds

    Returns:
        The final response or None on failure
    """

@ -125,11 +117,9 @@ async def _poll_job_status(
                timeout=timeout
            ) as response:
                if response.status == 200:
                    logger.info(f"RFDiffusion job {req_id} completed")
                    return await response.json()
                elif response.status == 202:
                    logger.debug(f"RFDiffusion job {req_id} still running, polling...")
                    await asyncio.sleep(polling_interval)
                else:

@ -139,4 +129,4 @@ async def _poll_job_status(
                    return None
    except Exception as e:
        logger.error(f"Error polling RFDiffusion job status: {e}")
        return None
    return None
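A hypothetical invocation (not from the commit) matching the call in tool_executor.py below. The contigs string follows the format the prompts explain later: 'A10-100/0 60' keeps target chain A residues 10-100, inserts a chain break, then diffuses a 60-residue binder.

import os

from environments.hack0.protein_design_env.models.rfdiffusion import call_rfdiffusion

async def make_backbone(target_pdb: str) -> None:
    result = await call_rfdiffusion(
        input_pdb=target_pdb,
        api_key=os.environ["NVIDIA_NIM_API_KEY"],
        contigs="A10-100/0 60",       # keep target A10-100, chain break, 60-residue binder
        hotspot_res=["A50", "A52"],   # optional: bias the interface toward these residues
        timeout=600,
        polling_interval=10,
    )
    if result and "output_pdb" in result:
        print("backbone PDB received")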
environments/hack0/protein_design_env/prompts.py (new file, 136 lines)

@ -0,0 +1,136 @@
import logging
from typing import Dict

logger = logging.getLogger(__name__)

SYSTEM_PROMPT = """You are a specialized AI system for de novo protein design via a staged simulation loop. Your objective is to generate binder sequences that are structurally and functionally optimized to bind a given target protein.

You will be guided through a multi-step pipeline:

1. Structure prediction (AlphaFold)
2. Binder backbone generation (RFdiffusion)
3. Sequence design (ProteinMPNN)
4. Structure evaluation (AlphaFold-Multimer)
5. Feedback loop

You must always:
- Respect the required file format for each tool (e.g., FASTA, PDB).
- Structure your outputs cleanly so they can be parsed and executed programmatically.
- Be explicit in all configuration steps (e.g., contigs, hotspots).
- Minimize ambiguity and verbosity; prefer concise, functional outputs.
- Reason step-by-step when appropriate.
"""

def construct_user_prompt(state: dict) -> str:
    """
    Constructs the appropriate user prompt for the current internal step.

    Args:
        state: The current workflow state dictionary (from episodes_state)

    Returns:
        A formatted user prompt string for the current step
    """
    internal_step = state.get("current_internal_step", 0)
    target_sequence = state.get("target_sequence")
    user_prompt_str = ""

    if internal_step == 0:
        user_prompt_str = (
            f"The target protein sequence is: {target_sequence}. "
            "Your first task is to predict its 3D structure using the 'predict_target_structure_alphafold2' tool. "
            "You must provide the 'sequence' argument."
        )
    elif internal_step == 1:
        target_pdb_preview = state.get("target_pdb_preview", "PDB preview not available.")

        chain_details = state.get("target_chain_details", {})
        if chain_details:
            chain_info_parts = []
            for chain_id, details in chain_details.items():
                min_r = details.get('min_residue', 'N/A')
                max_r = details.get('max_residue', 'N/A')
                l = details.get('length', 'N/A')
                chain_info_parts.append(f"Chain {chain_id} (Residues: {min_r}-{max_r}, Length: {l} amino acids)")
            chain_info_str = "\n- ".join(chain_info_parts)
            if chain_info_str:
                chain_info_str = "- " + chain_info_str
        else:
            chain_info_str = "Chain information not available or PDB not yet processed."

        user_prompt_str = (
            f"The 3D structure of the target protein has been predicted.\n"
            f"Target Protein Chain Details:\n{chain_info_str}\n\n"
            "Your task is to design a binder backbone using the 'design_binder_backbone_rfdiffusion' tool. "
            "You MUST specify 'contigs' for this tool. The 'contigs' string defines segments from the target PDB and segments for the new binder. "
            "Examples:\n"
            " - To use residues 10 through 100 of target chain A, and then diffuse a 60-residue binder: 'A10-100/0 60'\n"
            " - To use chain B from residue 5 to 50, then diffuse a 30-residue binder, then use chain B from residue 60 to 100: 'B5-50/0 30 B60-100'\n"
            "You MUST use the chain IDs and residue ranges exactly as provided in the 'Target Protein Chain Details' above. "
            "Do not invent chains or residue numbers outside these specified ranges for the target segments. "
            "For binder segments (e.g., '/0 60'), specify the desired length (e.g., 60).\n"
            "Optionally, provide 'hotspot_residues' (e.g., ['A50', 'A52']), ensuring these residues exist on the target as per the details above."
        )
    elif internal_step == 2:
        binder_pdb_content = state.get("binder_backbone_pdb_content")

        binder_pdb_preview = state.get("binder_pdb_preview", "Binder PDB preview not available.")
        binder_chain_info_str = "Binder chain information not available."

        if binder_pdb_content:
            binder_chain_details = state.get("binder_chain_details", {})

            if binder_chain_details:
                chain_info_parts = []
                for cID, d_details in binder_chain_details.items():
                    min_r = d_details.get('min_residue', 'N/A')
                    max_r = d_details.get('max_residue', 'N/A')
                    l = d_details.get('length', 'N/A')
                    chain_info_parts.append(f"Chain {cID} (Residues: {min_r}-{max_r}, Length: {l} amino acids)")
                binder_chain_info_str = "\n- ".join(chain_info_parts)
                if binder_chain_info_str:
                    binder_chain_info_str = "- " + binder_chain_info_str
            else:
                binder_chain_info_str = "Binder chain details not found in state (expected from RFDiffusion)."

        user_prompt_str = (
            f"A binder backbone has been generated. Binder PDB preview:\n{binder_pdb_preview}\n"
            f"Binder chain information:\n{binder_chain_info_str}.\n"
            "Now, design an optimal amino acid sequence for this binder backbone using the 'design_binder_sequence_proteinmpnn' tool. "
            "You can optionally specify 'sampling_temp' (e.g., [0.1, 0.2])."
        )
    elif internal_step == 3:
        designed_binder_seq_data = state.get("designed_binder_sequence")

        binder_display_str = "Not available"
        if isinstance(designed_binder_seq_data, list) and designed_binder_seq_data:
            if len(designed_binder_seq_data) == 1:
                binder_display_str = designed_binder_seq_data[0]
            else:
                binder_display_str = f"{len(designed_binder_seq_data)} chains: " + \
                    ", ".join([f"Chain {i+1} ({len(s)} aa): {s[:20]}..."
                               for i, s in enumerate(designed_binder_seq_data)])
        elif isinstance(designed_binder_seq_data, str):
            binder_display_str = designed_binder_seq_data

        user_prompt_str = (
            f"A binder has been designed. Designed binder sequence(s): {binder_display_str}. "
            f"The original target sequence was: {target_sequence[:60]}...\n"
            "Finally, evaluate the binding complex of the original target protein and ALL chains of this designed binder using the "
            "'evaluate_binder_complex_alphafold2_multimer' tool. "
            "You can optionally specify 'relax_prediction' (default is True)."
        )
    else:
        user_prompt_str = "The protein design workflow is complete. No further actions are required from you for this item. If successful, the key metric was the pLDDT of the complex."

    if state.get("retry_count_this_internal_step", 0) > 0 and internal_step < 4:
        retry_prefix = "Your previous attempt at this step was not successful. "
        if state.get("previous_tool_error_message"):
            retry_prefix += f"Details: {state['previous_tool_error_message']}. "
        retry_prefix += "Please review the requirements and PDB details carefully and try again to correctly use the expected tool.\n\n"
        user_prompt_str = retry_prefix + user_prompt_str

    return user_prompt_str
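A quick illustration (not from the commit) of the step-0 prompt this builder produces, given a minimal state dict:

state = {
    "current_internal_step": 0,
    "target_sequence": "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",  # toy sequence
}
print(construct_user_prompt(state))
# The target protein sequence is: MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ. Your first task is to
# predict its 3D structure using the 'predict_target_structure_alphafold2' tool. You must
# provide the 'sequence' argument.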
File diff suppressed because it is too large
environments/hack0/protein_design_env/tool_definitions.py (new file, 67 lines)

@ -0,0 +1,67 @@
PREDICT_TARGET_STRUCTURE_TOOL = {
    "type": "function",
    "function": {
        "name": "predict_target_structure_alphafold2",
        "description": "Predicts the 3D structure of the target protein sequence using AlphaFold2.",
        "parameters": {
            "type": "object",
            "properties": {
                "sequence": {"type": "string", "description": "Amino acid sequence of the target protein."},
            },
            "required": ["sequence"]
        }
    }
}

DESIGN_BINDER_BACKBONE_TOOL = {
    "type": "function",
    "function": {
        "name": "design_binder_backbone_rfdiffusion",
        "description": "Generates a novel protein binder backbone using RFDiffusion, conditioned on the target protein structure.",
        "parameters": {
            "type": "object",
            "properties": {
                "contigs": {"type": "string", "description": "RFDiffusion contigs (e.g., 'A1-100/0 50-70')."},
                "hotspot_residues": {"type": "array", "items": {"type": "string"}, "description": "Optional hotspot residues (e.g., ['A50', 'A52'])."},
            },
            "required": ["contigs"]
        }
    }
}

DESIGN_BINDER_SEQUENCE_TOOL = {
    "type": "function",
    "function": {
        "name": "design_binder_sequence_proteinmpnn",
        "description": "Designs an amino acid sequence for the generated binder backbone.",
        "parameters": {
            "type": "object",
            "properties": {
                "sampling_temp": {"type": "array", "items": {"type": "number"}, "description": "Sampling temperatures (e.g., [0.1, 0.2]). Default [0.1]."}
            },
            "required": []
        }
    }
}

EVALUATE_COMPLEX_TOOL = {
    "type": "function",
    "function": {
        "name": "evaluate_binder_complex_alphafold2_multimer",
        "description": "Predicts the complex structure of target and designed binder, providing quality scores.",
        "parameters": {
            "type": "object",
            "properties": {
                "relax_prediction": {"type": "boolean", "description": "Whether to relax the prediction. Default True."}
            },
            "required": []
        }
    }
}

ALL_TOOLS_LIST = [
    PREDICT_TARGET_STRUCTURE_TOOL,
    DESIGN_BINDER_BACKBONE_TOOL,
    DESIGN_BINDER_SEQUENCE_TOOL,
    EVALUATE_COMPLEX_TOOL
]
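These dicts follow the OpenAI-style function-calling schema, so they can be passed straight through as the tools argument of a chat-completions request. A hedged sketch (not from the commit; the client setup and model name are placeholders):

from openai import OpenAI

from environments.hack0.protein_design_env.tool_definitions import ALL_TOOLS_LIST

client = OpenAI()  # placeholder: any OpenAI-compatible endpoint works

response = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model name
    messages=[
        {"role": "system", "content": "..."},  # e.g., SYSTEM_PROMPT from prompts.py
        {"role": "user", "content": "..."},    # e.g., construct_user_prompt(state)
    ],
    tools=ALL_TOOLS_LIST,
    tool_choice="auto",
)
tool_calls = response.choices[0].message.tool_calls  # parsed by the environment's step logic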
environments/hack0/protein_design_env/tool_executor.py (new file, 472 lines)

@ -0,0 +1,472 @@
import logging
import json
import re
from typing import Dict, Any, List, Tuple, Optional, Union
from pathlib import Path
from environments.hack0.protein_design_env.models.alphafold2 import call_alphafold2
from environments.hack0.protein_design_env.models.rfdiffusion import call_rfdiffusion
from environments.hack0.protein_design_env.models.proteinmpnn import call_proteinmpnn
from environments.hack0.protein_design_env.models.alphafold2_multimer import call_alphafold2_multimer
from environments.hack0.protein_design_env.utils.pdb_utils import get_pdb_chain_details

logger = logging.getLogger(__name__)

class ToolExecutor:
    def __init__(self, nim_api_key: str, api_timeout: int, polling_interval: int,
                 output_dir: Path, debug_protein_design_calls: bool):
        self.nim_api_key = nim_api_key
        self.api_timeout = api_timeout
        self.polling_interval = polling_interval
        self.output_dir = output_dir
        self.debug_protein_design_calls = debug_protein_design_calls
        self._debug_af2m_call_count = 0

    def _validate_rfd_contigs(self, contigs_str: str, target_chain_details: Dict[str, Dict[str, int]]) -> Optional[str]:
        """
        Validates the RFDiffusion contigs string against target PDB chain details.
        Returns None if valid, or an error message string if invalid.
        """
        if not contigs_str:
            return "Contigs string is empty."

        target_segment_pattern = re.compile(r"([A-Za-z0-9])(\d+)-(\d+)")
        active_contig_parts = contigs_str.split('/')

        for part in active_contig_parts:
            chain_segments_in_part = part.strip().split(' ')
            for segment_text in chain_segments_in_part:
                segment_text = segment_text.strip()
                if not segment_text or segment_text.isdigit():
                    continue

                match = target_segment_pattern.fullmatch(segment_text)
                if match:
                    seg_chain_id, seg_start_str, seg_end_str = match.groups()
                    seg_start = int(seg_start_str)
                    seg_end = int(seg_end_str)

                    if seg_chain_id not in target_chain_details:
                        return f"Chain '{seg_chain_id}' in contig segment '{segment_text}' not in target. Valid: {list(target_chain_details.keys())}."

                    chain_min = target_chain_details[seg_chain_id]["min_residue"]
                    chain_max = target_chain_details[seg_chain_id]["max_residue"]

                    if not (chain_min <= seg_start <= chain_max and chain_min <= seg_end <= chain_max and seg_start <= seg_end):
                        return (f"Residue range {seg_start}-{seg_end} for chain '{seg_chain_id}' in '{segment_text}' "
                                f"is invalid/out of bounds. Chain '{seg_chain_id}' actual range: {chain_min}-{chain_max}.")
        return None

    def _validate_rfd_hotspots(self, hotspot_list: List[str], target_chain_details: Dict[str, Dict[str, int]]) -> Optional[str]:
        """
        Validates hotspot residues (e.g., ["A50", "B25"]) against target PDB chain details.
        Returns None if valid, or an error message string if invalid.
        """
        if not hotspot_list:
            return None

        hotspot_pattern = re.compile(r"([A-Za-z0-9])(\d+)")

        for hotspot_str in hotspot_list:
            match = hotspot_pattern.fullmatch(hotspot_str.strip())
            if not match:
                return f"Hotspot '{hotspot_str}' is not in the expected format (e.g., 'A50')."

            hs_chain_id, hs_res_num_str = match.groups()
            hs_res_num = int(hs_res_num_str)

            if hs_chain_id not in target_chain_details:
                return f"Chain '{hs_chain_id}' for hotspot '{hotspot_str}' not in target. Valid: {list(target_chain_details.keys())}."

            chain_min = target_chain_details[hs_chain_id]["min_residue"]
            chain_max = target_chain_details[hs_chain_id]["max_residue"]

            if not (chain_min <= hs_res_num <= chain_max):
                return (f"Residue {hs_res_num} for hotspot '{hotspot_str}' (chain '{hs_chain_id}') "
                        f"out of bounds. Chain '{hs_chain_id}' actual range: {chain_min}-{chain_max}.")
        return None
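For illustration (not in the commit), here is how these validators behave on a toy chain map in the shape that get_pdb_chain_details() produces:

target_chains = {"A": {"min_residue": 1, "max_residue": 120, "length": 120}}
executor = ToolExecutor(nim_api_key="...", api_timeout=600, polling_interval=10,
                        output_dir=Path("./outputs"), debug_protein_design_calls=True)

print(executor._validate_rfd_contigs("A10-100/0 60", target_chains))    # None (valid)
print(executor._validate_rfd_contigs("B10-100/0 60", target_chains))    # error: chain 'B' not in target
print(executor._validate_rfd_hotspots(["A50", "A200"], target_chains))  # error: residue 200 out of bounds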
    async def _run_nim_alphafold2(self, args: Dict, workflow_state: Dict) -> Dict:
        """
        Runs AlphaFold2 for target structure prediction. Returns structured output with
        tool_output and state_updates separated.
        """
        item_id = workflow_state["item_id"]
        current_internal_step = workflow_state["current_internal_step"]
        target_sequence_from_state = workflow_state["target_sequence"]

        tool_output = {}
        state_updates = {}

        if self.debug_protein_design_calls:
            logger.warning(f"DEBUG MODE: Bypassing AlphaFold2 API call for workflow {item_id}.")
            project_root = Path(__file__).parent.parent.parent.parent
            fixed_pdb_path = project_root / "binder_outputs" / "target.pdb"

            if not fixed_pdb_path.exists():
                logger.error(f"Debug mode failed: {fixed_pdb_path} not found.")
                tool_output = {"success": False, "error": f"Debug mode failed: Required file {fixed_pdb_path} not found."}
                return {"tool_output": tool_output, "state_updates": state_updates}

            with open(fixed_pdb_path, "r") as f:
                pdb_content = f.read()

            chain_details, pdb_preview = get_pdb_chain_details(pdb_content)

            state_updates["target_pdb_content"] = pdb_content
            state_updates["target_chain_details"] = chain_details
            state_updates["target_pdb_preview"] = pdb_preview
            state_updates["target_structure_predicted"] = True

            debug_pdb_path = self.output_dir / f"target_{item_id}_s{current_internal_step}_af2_DEBUG.pdb"
            with open(debug_pdb_path, "w") as f:
                f.write(pdb_content)
            logger.info(f"DEBUG MODE: Copied fixed AlphaFold2 PDB to {debug_pdb_path}")

            tool_output = {
                "success": True,
                "message": "DEBUG MODE: Used fixed PDB for AlphaFold2.",
                "target_pdb_preview": pdb_preview,
                "saved_pdb_path": str(debug_pdb_path)
            }
            return {"tool_output": tool_output, "state_updates": state_updates}

        sequence_from_llm = args.get("sequence")
        if not sequence_from_llm:
            tool_output = {"success": False, "error": "Missing 'sequence' for AlphaFold2."}
            return {"tool_output": tool_output, "state_updates": state_updates}

        actual_sequence_to_use = target_sequence_from_state
        if sequence_from_llm != target_sequence_from_state:
            logger.warning(
                f"LLM provided sequence '{sequence_from_llm[:20]}...' for 'predict_target_structure_alphafold2'. "
                f"However, this tool will use the canonical target sequence from the workflow state: '{target_sequence_from_state[:20]}...'"
            )

        api_result = await call_alphafold2(
            sequence=actual_sequence_to_use, api_key=self.nim_api_key,
            timeout=self.api_timeout, polling_interval=self.polling_interval
        )
        if api_result and isinstance(api_result, list) and api_result[0]:
            pdb_content = api_result[0]
            chain_details, pdb_preview = get_pdb_chain_details(pdb_content)

            state_updates["target_pdb_content"] = pdb_content
            state_updates["target_chain_details"] = chain_details
            state_updates["target_pdb_preview"] = pdb_preview
            state_updates["target_structure_predicted"] = True

            pdb_path = self.output_dir / f"target_{item_id}_s{current_internal_step}_af2.pdb"
            with open(pdb_path, "w") as f:
                f.write(pdb_content)
            logger.info(f"Workflow {item_id}: AlphaFold2 PDB saved to {pdb_path}. Chain details: {chain_details}")

            tool_output = {"success": True, "message": "AlphaFold2 prediction complete.", "target_pdb_preview": pdb_preview, "saved_pdb_path": str(pdb_path)}
        else:
            error_detail = api_result.get("error", "AlphaFold2 prediction failed.") if isinstance(api_result, dict) else "AlphaFold2 prediction failed."
            logger.error(f"Workflow {item_id}: AlphaFold2 call failed: {error_detail}")
            tool_output = {"success": False, "error": error_detail}
            state_updates["target_structure_predicted"] = False

        return {"tool_output": tool_output, "state_updates": state_updates}
    async def _run_nim_rfdiffusion(self, args: Dict, workflow_state: Dict) -> Dict:
        """
        Runs RFDiffusion for binder backbone design. Returns structured output with
        tool_output and state_updates separated.
        """
        item_id = workflow_state["item_id"]
        current_internal_step = workflow_state["current_internal_step"]
        target_pdb_content = workflow_state.get("target_pdb_content")
        target_chain_details = workflow_state.get("target_chain_details", {})

        tool_output = {}
        state_updates = {}

        contigs_str_from_llm = args.get("contigs")
        if not target_pdb_content:
            tool_output = {"success": False, "error": "Target PDB not found in state for RFDiffusion."}
            return {"tool_output": tool_output, "state_updates": state_updates}
        if not contigs_str_from_llm:
            tool_output = {"success": False, "error": "Missing 'contigs' for RFDiffusion."}
            return {"tool_output": tool_output, "state_updates": state_updates}

        validation_error = self._validate_rfd_contigs(contigs_str_from_llm, target_chain_details)
        if validation_error:
            logger.warning(f"RFDiffusion contigs validation failed for item {item_id}: {validation_error}. Contigs: '{contigs_str_from_llm}'")
            tool_output = {"success": False, "error": f"Invalid contigs: {validation_error}"}
            return {"tool_output": tool_output, "state_updates": state_updates}

        hotspot_residues = args.get("hotspot_residues")
        if hotspot_residues:
            hotspot_validation_error = self._validate_rfd_hotspots(hotspot_residues, target_chain_details)
            if hotspot_validation_error:
                logger.warning(f"RFDiffusion hotspot validation failed for item {item_id}: {hotspot_validation_error}. Hotspots: {hotspot_residues}")
                tool_output = {"success": False, "error": f"Invalid hotspots: {hotspot_validation_error}"}
                return {"tool_output": tool_output, "state_updates": state_updates}

        api_result = await call_rfdiffusion(
            input_pdb=target_pdb_content, api_key=self.nim_api_key,
            contigs=contigs_str_from_llm, hotspot_res=hotspot_residues,
            timeout=self.api_timeout, polling_interval=self.polling_interval
        )

        if api_result and "output_pdb" in api_result:
            binder_pdb = api_result["output_pdb"]
            binder_chain_details, binder_pdb_preview = get_pdb_chain_details(binder_pdb)

            state_updates["binder_backbone_pdb_content"] = binder_pdb
            state_updates["binder_chain_details"] = binder_chain_details
            state_updates["binder_pdb_preview"] = binder_pdb_preview
            state_updates["binder_backbone_designed"] = True

            pdb_path = self.output_dir / f"binder_backbone_{item_id}_s{current_internal_step}_rfd.pdb"
            with open(pdb_path, "w") as f:
                f.write(binder_pdb)
            logger.info(f"Workflow {item_id}: RFDiffusion PDB saved to {pdb_path}")

            tool_output = {"success": True, "message": "RFDiffusion complete.", "binder_backbone_pdb_preview": binder_pdb_preview, "saved_pdb_path": str(pdb_path)}
        else:
            error_detail = api_result.get("error", "RFDiffusion failed.") if isinstance(api_result, dict) else "RFDiffusion failed."
            logger.error(f"Workflow {item_id}: RFDiffusion call failed: {error_detail}. API Result: {api_result}")
            tool_output = {"success": False, "error": error_detail}
            state_updates["binder_backbone_designed"] = False

        return {"tool_output": tool_output, "state_updates": state_updates}
async def _run_nim_proteinmpnn(self, args: Dict, workflow_state: Dict) -> Dict:
|
||||
"""
|
||||
Runs ProteinMPNN for binder sequence design. Returns structured output with
|
||||
tool_output and state_updates separated.
|
||||
"""
|
||||
item_id = workflow_state["item_id"]
|
||||
current_internal_step = workflow_state["current_internal_step"]
|
||||
binder_pdb = workflow_state.get("binder_backbone_pdb_content")
|
||||
|
||||
tool_output = {}
|
||||
state_updates = {}
|
||||
|
||||
if not binder_pdb:
|
||||
tool_output = {"success": False, "error": "Binder backbone PDB not found for ProteinMPNN."}
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
sampling_temp_list = args.get("sampling_temp", [0.1])
|
||||
|
||||
api_result = await call_proteinmpnn(
|
||||
input_pdb=binder_pdb, api_key=self.nim_api_key,
|
||||
sampling_temp=sampling_temp_list,
|
||||
timeout=self.api_timeout, polling_interval=self.polling_interval
|
||||
)

        if not (api_result and "mfasta" in api_result):
            error_detail = api_result.get("error", "ProteinMPNN call failed or no mfasta in result.") if isinstance(api_result, dict) else "ProteinMPNN call failed."
            logger.error(f"Workflow {item_id}: ProteinMPNN call failed: {error_detail}. API Result: {api_result}")
            tool_output = {"success": False, "error": error_detail}
            state_updates["binder_sequence_designed"] = False
            return {"tool_output": tool_output, "state_updates": state_updates}
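
        # Illustrative mfasta header (assumed format; field order can vary):
        #   >T=0.1, sample=1, score=0.7339, global_score=0.7516, seq_recovery=0.5055
        # Only the global_score field is parsed below.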
        fasta_content = api_result["mfasta"]
        entries: List[Tuple[float, str, str]] = []
        current_header = None
        current_sequence_parts: List[str] = []
        for line_content in fasta_content.splitlines():
            line = line_content.strip()
            if not line:
                continue
            if line.startswith(">"):
                if current_header and current_sequence_parts:
                    full_sequence_line = "".join(current_sequence_parts)
                    score_match = re.search(r"global_score=([-\d.]+)", current_header)
                    global_score = float(score_match.group(1)) if score_match else -float('inf')
                    entries.append((global_score, current_header, full_sequence_line))
                current_header = line
                current_sequence_parts = []
            else:
                current_sequence_parts.append(line)
        if current_header and current_sequence_parts:
            full_sequence_line = "".join(current_sequence_parts)
            score_match = re.search(r"global_score=([-\d.]+)", current_header)
            global_score = float(score_match.group(1)) if score_match else -float('inf')
            entries.append((global_score, current_header, full_sequence_line))

        if not entries:
            tool_output = {"success": False, "error": "No sequences parsed from ProteinMPNN output."}
            state_updates["binder_sequence_designed"] = False
            return {"tool_output": tool_output, "state_updates": state_updates}
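
        # NOTE: raw ProteinMPNN global_score values are average negative log-probabilities,
        # where lower is usually better; sorting with reverse=True assumes the NIM endpoint
        # reports a higher-is-better score.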
        entries.sort(key=lambda x: x[0], reverse=True)
        best_global_score, best_header, best_full_sequence_line = entries[0]
        logger.info(f"Workflow {item_id}: Best ProteinMPNN sequence chosen (global_score={best_global_score:.4f}) from header: '{best_header}' -> Seq line: '{best_full_sequence_line}'")
        parsed_binder_chains = [s.strip() for s in best_full_sequence_line.split('/') if s.strip()]

        if not parsed_binder_chains or not all(s and s.isalpha() and s.isupper() for s in parsed_binder_chains):
            tool_output = {"success": False, "error": f"Invalid binder chains from ProteinMPNN after parsing '{best_full_sequence_line}'. Parsed: {parsed_binder_chains}"}
            state_updates["binder_sequence_designed"] = False
            return {"tool_output": tool_output, "state_updates": state_updates}

        state_updates["designed_binder_sequence"] = parsed_binder_chains
        state_updates["binder_sequence_designed"] = True

        fasta_path = self.output_dir / f"binder_sequence_{item_id}_s{current_internal_step}_pmpnn.fasta"
        with open(fasta_path, "w") as f:
            f.write(fasta_content)
        logger.info(f"Workflow {item_id}: ProteinMPNN FASTA saved to {fasta_path}. Selected binder chains: {parsed_binder_chains}")

        preview = parsed_binder_chains[0][:60]
        if len(parsed_binder_chains[0]) > 60:
            preview += "..."
        if len(parsed_binder_chains) > 1:
            preview += f" (+ {len(parsed_binder_chains)-1} more chain(s))"

        tool_output = {
            "success": True,
            "message": f"ProteinMPNN complete. Selected best (global_score={best_global_score:.4f}).",
            "designed_binder_sequence_list": parsed_binder_chains,
            "designed_binder_sequence_preview": preview,
            "saved_fasta_path": str(fasta_path)
        }
        return {"tool_output": tool_output, "state_updates": state_updates}


    async def _run_nim_af2_multimer(self, args: Dict, workflow_state: Dict) -> Dict:
        """
        Runs AlphaFold2-Multimer to evaluate the target-binder complex. Returns structured
        output with tool_output and state_updates separated.
        """
        item_id = workflow_state["item_id"]
        current_internal_step = workflow_state["current_internal_step"]
        target_seq = workflow_state.get("target_sequence")
        designed_binder_chains_list = workflow_state.get("designed_binder_sequence")

        tool_output = {}
        state_updates = {}

        if not target_seq or not designed_binder_chains_list or not isinstance(designed_binder_chains_list, list):
            tool_output = {"success": False, "error": "Missing or invalid sequences for AF2-Multimer."}
            return {"tool_output": tool_output, "state_updates": state_updates}

        all_input_sequences_for_multimer = [target_seq] + designed_binder_chains_list
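        # The target sequence goes first; each designed binder chain follows as its own
        # entry, so AlphaFold2-Multimer folds them as separate chains of one complex.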

        for i, seq_to_validate in enumerate(all_input_sequences_for_multimer):
            if not (seq_to_validate and isinstance(seq_to_validate, str) and seq_to_validate.isalpha() and seq_to_validate.isupper()):
                error_msg = (f"Sequence {i+1} (part of target/binder complex) is invalid: "
                             f"'{str(seq_to_validate)[:30]}...'. It is empty, not a string, or "
                             f"contains non-alphabetic or lowercase characters.")
                logger.error(f"Workflow {item_id}: {error_msg}")
                tool_output = {"success": False, "error": error_msg}
                return {"tool_output": tool_output, "state_updates": state_updates}

        relax = args.get("relax_prediction", True)  # honor the LLM-provided argument, defaulting to True
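
        # Debug mode alternates high (87.5) and low (45.2) mock pLDDT values between calls,
        # so downstream logic sees both high- and low-quality results without hitting the API.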
        if self.debug_protein_design_calls:
            self._debug_af2m_call_count += 1
            mock_plddt = 87.5 if self._debug_af2m_call_count % 2 == 1 else 45.2
            success_message = f"DEBUG MODE: Returning {'high' if mock_plddt > 50 else 'low'}-quality mock results (call #{self._debug_af2m_call_count})"

            mock_pdb_path = self.output_dir / f"mock_complex_{item_id}_s{current_internal_step}_af2m.pdb"
            with open(mock_pdb_path, "w") as f:
                f.write(f"MOCK PDB FILE for complex. Predicted pLDDT {mock_plddt}\n")

            state_updates["complex_pdb_content_path"] = str(mock_pdb_path)
            state_updates["af2_multimer_plddt"] = mock_plddt
            state_updates["complex_evaluated"] = True

            tool_output = {
                "success": True, "message": f"{success_message}. Mock pLDDT: {mock_plddt:.2f}",
                "plddt": mock_plddt,
                "complex_file_path": str(mock_pdb_path)
            }
            return {"tool_output": tool_output, "state_updates": state_updates}

        output_subdir = self.output_dir / f"alphafold2_multimer_{item_id}_s{current_internal_step}"
        logger.info(f"Using output directory for AlphaFold2-Multimer results: {output_subdir}")

        api_result = await call_alphafold2_multimer(
            sequences=all_input_sequences_for_multimer,
            api_key=self.nim_api_key,
            relax_prediction=relax,
            timeout=self.api_timeout,
            polling_interval=self.polling_interval,
            output_dir=output_subdir
        )

        if isinstance(api_result, dict):
            if "success" in api_result and api_result["success"] is False:
                error_detail = api_result.get("error", "AF2-Multimer call failed with error.")
                detail_info = api_result.get("detail", "")
                if detail_info:
                    error_detail += f" Details: {detail_info}"
                logger.error(f"Workflow {item_id}: AF2-Multimer call failed: {error_detail}")
                tool_output = {"success": False, "error": error_detail}
                state_updates["complex_evaluated"] = False
                return {"tool_output": tool_output, "state_updates": state_updates}

            if "structures" in api_result and len(api_result["structures"]) > 0:
                all_structures_info = api_result["structures"]

                best_structure_info = None
                highest_plddt = -1.0
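
                # The best model is picked by average pLDDT alone; pTM/ipTM are not
                # considered at this selection step.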
                for struct_info in all_structures_info:
                    current_plddt = struct_info.get("average_plddt", 0.0)
                    if current_plddt > highest_plddt:
                        highest_plddt = current_plddt
                        best_structure_info = struct_info

                if best_structure_info is None:  # should not happen if all_structures_info was non-empty
                    logger.error(f"Workflow {item_id}: No valid structure with pLDDT found in AF2-Multimer results.")
                    tool_output = {"success": False, "error": "No valid structure with pLDDT in AF2-Multimer results."}
                    state_updates["complex_evaluated"] = False
                    return {"tool_output": tool_output, "state_updates": state_updates}

                best_plddt = best_structure_info.get("average_plddt", 0.0)
                best_pdb_path = best_structure_info.get("saved_pdb_path")
                best_model_idx = best_structure_info.get("model_index", "N/A")

                state_updates["complex_pdb_content_path"] = best_pdb_path
                state_updates["af2_multimer_plddt"] = best_plddt
                state_updates["complex_evaluated"] = True

                logger.info(f"Workflow {item_id}: AlphaFold2-Multimer complete. Selected best model (Index {best_model_idx}) with pLDDT: {best_plddt:.2f} from {len(all_structures_info)} models. PDB: {best_pdb_path}")

                complex_quality_message = f"AlphaFold2-Multimer evaluation complete. Selected best model (Index {best_model_idx}) with pLDDT: {best_plddt:.2f}"

                tool_output = {
                    "success": True,
                    "message": complex_quality_message,
                    "plddt": best_plddt,
                    "complex_file_path": best_pdb_path,
                    "selected_model_index": best_model_idx
                }
                return {"tool_output": tool_output, "state_updates": state_updates}

        error_detail = "AF2-Multimer call failed or returned unexpected data format."
        if isinstance(api_result, dict) and "error" in api_result:
            error_detail = api_result["error"]

        logger.error(f"Workflow {item_id}: AF2-Multimer call failed: {error_detail}. Full API Result: {api_result}")
        tool_output = {"success": False, "error": error_detail}
        state_updates["complex_evaluated"] = False

        return {"tool_output": tool_output, "state_updates": state_updates}


    async def dispatch_tool_call(self, tool_name: str, args: Dict, workflow_state: Dict) -> Dict:
        """Main dispatch method for executing tools."""
        item_id = workflow_state["item_id"]
        internal_step = workflow_state["current_internal_step"]
        logger.info(f"ToolExecutor: Dispatching tool '{tool_name}' for workflow {item_id}, Step {internal_step} with args: {args}")

        if not self.nim_api_key:
            return {
                "tool_output": {"success": False, "error": "NIM API key not configured in ToolExecutor."},
                "state_updates": {}
            }

        if tool_name == "predict_target_structure_alphafold2":
            return await self._run_nim_alphafold2(args, workflow_state)
        elif tool_name == "design_binder_backbone_rfdiffusion":
            return await self._run_nim_rfdiffusion(args, workflow_state)
        elif tool_name == "design_binder_sequence_proteinmpnn":
            return await self._run_nim_proteinmpnn(args, workflow_state)
        elif tool_name == "evaluate_binder_complex_alphafold2_multimer":
            return await self._run_nim_af2_multimer(args, workflow_state)
        else:
            logger.error(f"ToolExecutor: Unknown tool name '{tool_name}' for workflow {item_id}")
            return {
                "tool_output": {"success": False, "error": f"Unknown tool name: {tool_name}"},
                "state_updates": {}
            }
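
    # Minimal usage sketch (illustrative; assumes an initialized ToolExecutor instance
    # `executor` and a populated workflow_state dict):
    #
    #     result = await executor.dispatch_tool_call(
    #         "design_binder_backbone_rfdiffusion",
    #         {"contigs": "A1-100/0 70-100"},
    #         workflow_state,
    #     )
    #     if result["tool_output"]["success"]:
    #         workflow_state.update(result["state_updates"])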

@ -1 +1,5 @@
"""Utility functions for the protein design environment."""

from .pdb_utils import get_pdb_chain_details

__all__ = ["get_pdb_chain_details"]
@ -1,5 +1,3 @@
"""API utility functions for the protein design environment."""

import os
import logging
import yaml

@ -7,15 +5,13 @@ from pathlib import Path
from typing import Optional
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

logger = logging.getLogger(__name__)

def load_api_key() -> Optional[str]:
    """
    Load the NVIDIA NIM API key from environment variables.

    Returns:
        The API key from environment variables, or None if not found
    """

@ -24,5 +20,5 @@ def load_api_key() -> Optional[str]:
    logger.error("NVIDIA_NIM_API_KEY not found in environment variables. "
                 "Please set it in your .env file.")
    return None

    return api_key
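
# Illustrative .env entry (the value is a placeholder):
#   NVIDIA_NIM_API_KEY=nvapi-xxxxxxxxxxxx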

environments/hack0/protein_design_env/utils/pdb_utils.py (new file, 65 lines)
@ -0,0 +1,65 @@
import logging
from typing import Dict, Tuple, List, Set, Union

logger = logging.getLogger(__name__)

def get_pdb_chain_details(pdb_content: str, preview_lines: int = 10) -> Tuple[Dict[str, Dict[str, int]], str]:
    """
    Parses PDB content to extract detailed information for each chain.

    Returns:
        A tuple containing:
        - chain_details (Dict[str, Dict[str, int]]):
            A dictionary where keys are chain IDs (e.g., "A").
            Each value is another dictionary:
            {
                "min_residue": int,  # smallest residue number found for this chain
                "max_residue": int,  # largest residue number found for this chain
                "length": int        # count of C-alpha atoms (residues) in this chain
            }
        - pdb_preview (str): A string preview of the PDB content.
    """
    chain_info_temp: Dict[str, Dict[str, Union[Set[int], int]]] = {}
    atom_lines = []
    header_lines = []

    for line in pdb_content.splitlines():
        if line.startswith("ATOM"):
            atom_lines.append(line)
            chain_id = line[21:22].strip()
            if not chain_id:
                chain_id = " "
            atom_name = line[12:16].strip()
            try:
                residue_num = int(line[22:26].strip())
                if chain_id not in chain_info_temp:
                    chain_info_temp[chain_id] = {"residues": set(), "ca_count": 0}
                chain_info_temp[chain_id]["residues"].add(residue_num)
                if atom_name == "CA":
                    chain_info_temp[chain_id]["ca_count"] += 1
            except ValueError:
                logger.warning(f"Could not parse residue number from PDB line: {line}")
                continue
        elif line.startswith("HEADER") or line.startswith("TITLE") or line.startswith("COMPND"):
            header_lines.append(line)

    chain_details: Dict[str, Dict[str, int]] = {}
    for chain_id, data in chain_info_temp.items():
        if data["residues"]:
            min_res = min(data["residues"])
            max_res = max(data["residues"])
            length = data["ca_count"] if data["ca_count"] > 0 else len(data["residues"])
            chain_details[chain_id] = {
                "min_residue": min_res,
                "max_residue": max_res,
                "length": length
            }
        else:
            logger.warning(f"Chain {chain_id} had no parseable ATOM residue numbers.")

    preview_str_parts = header_lines[:min(len(header_lines), preview_lines // 2)]
    remaining_preview_lines = preview_lines - len(preview_str_parts)
    preview_str_parts.extend(atom_lines[:min(len(atom_lines), remaining_preview_lines)])
    pdb_preview = "\n".join(preview_str_parts)
    if len(pdb_content.splitlines()) > preview_lines:
        pdb_preview += "\n..."
    return chain_details, pdb_preview
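
# Illustrative example (column-aligned PDB fragment; coordinate fields are placeholders):
#
#     pdb = (
#         "ATOM      1  CA  ALA A   1      11.104  13.207   2.100  1.00  0.00\n"
#         "ATOM      2  CA  GLY A   2      12.560  14.175   5.300  1.00  0.00\n"
#     )
#     details, preview = get_pdb_chain_details(pdb)
#     # details == {"A": {"min_residue": 1, "max_residue": 2, "length": 2}}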