This commit is contained in:
Shannon Sands 2025-05-27 12:15:15 +10:00
parent 13a70e09ab
commit 54967ecae9
19 changed files with 1337 additions and 531 deletions

View file

@ -2040,6 +2040,65 @@ python test_stl_env.py
---
### 23. Protein Design Environment (`protein_design/`)
**Contributors**: hallerite, promachina
**PR**: [#70](https://github.com/NousResearch/atropos/pull/70)
**Integration Status**: ✅ Integrated
**Description**: A comprehensive reinforcement learning environment for de novo protein design through a staged simulation loop. This environment enables AI systems to learn the complete protein design workflow from target structure prediction to binder evaluation, using state-of-the-art protein modeling tools.
**Core Features**:
**Multi-Stage Protein Design Pipeline**:
- **AlphaFold2 Structure Prediction**: Predicts 3D structure of target proteins from amino acid sequences
- **RFDiffusion Backbone Generation**: Generates novel protein binder backbones conditioned on target structures
- **ProteinMPNN Sequence Design**: Designs optimal amino acid sequences for generated backbones
- **AlphaFold2-Multimer Evaluation**: Evaluates binding complex quality with pLDDT scoring
**Advanced Workflow Management**:
- **State-Based Progression**: Tracks workflow state through 4 distinct internal steps
- **Retry Logic**: Configurable retry mechanisms for failed tool executions
- **Validation Systems**: Comprehensive input validation for contigs, hotspots, and sequences
- **Error Handling**: Robust error recovery and detailed logging
**NVIDIA NIM Integration**:
- **API-Based Execution**: Leverages NVIDIA NIM APIs for protein modeling tools
- **Async Processing**: Non-blocking API calls with configurable timeouts and polling
- **Debug Mode**: Mock data generation for development and testing
- **Result Caching**: Saves intermediate PDB files and FASTA sequences
**Reward System**:
- **Format Rewards**: 0.2 points for correct tool usage in steps 0-2
- **Quality Rewards**: pLDDT-based scoring (0.0-1.0) for final complex evaluation
- **Progressive Scoring**: Cumulative rewards across workflow stages
**Data Management**:
- **Hugging Face Integration**: Loads protein binding datasets (ronig/protein_binding_sequences)
- **File Organization**: Structured output directory with timestamped results
- **Comprehensive Logging**: Detailed workflow tracking and performance metrics
**Research Applications**:
- **Drug Discovery**: Design novel protein binders for therapeutic targets
- **Protein Engineering**: Optimize protein-protein interactions
- **Structural Biology**: Explore protein design space systematically
- **AI Training**: Develop protein design capabilities in language models
**Technical Requirements**:
- NVIDIA NIM API access for protein modeling tools
- Python environment with protein analysis libraries
- Sufficient storage for PDB files and intermediate results
**Environment Configuration**:
- Configurable retry limits and timeout settings
- Debug mode for development without API calls
- Flexible dataset selection and column mapping
- WandB integration for experiment tracking
**Requirements**: pydantic, datasets, python-dotenv, pyyaml, wandb, atroposlib, nvidia-nim-api-client
---
## Support

For questions or issues with community environments:

View file

@ -0,0 +1 @@
"""Protein design model API modules."""

View file

@ -1,16 +1,15 @@
import os
import logging
import aiohttp
import json
import asyncio import asyncio
from typing import Dict, List, Any, Optional import logging
from pathlib import Path from typing import Any, Dict, List, Optional
import aiohttp
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2" DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status" DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
async def call_alphafold2( async def call_alphafold2(
sequence: str, sequence: str,
api_key: str, api_key: str,
@ -24,7 +23,7 @@ async def call_alphafold2(
status_url: str = DEFAULT_STATUS_URL, status_url: str = DEFAULT_STATUS_URL,
polling_interval: int = 10, polling_interval: int = 10,
timeout: int = 600, timeout: int = 600,
max_retries: int = 3 max_retries: int = 3,
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
""" """
Call the NVIDIA NIM AlphaFold2 API. Call the NVIDIA NIM AlphaFold2 API.
@ -59,16 +58,13 @@ async def call_alphafold2(
"iterations": iterations, "iterations": iterations,
"databases": databases, "databases": databases,
"relax_prediction": relax_prediction, "relax_prediction": relax_prediction,
"skip_template_search": skip_template_search "skip_template_search": skip_template_search,
} }
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post( async with session.post(
url, url, json=data, headers=headers, timeout=timeout
json=data,
headers=headers,
timeout=timeout
) as response: ) as response:
if response.status == 200: if response.status == 200:
return await response.json() return await response.json()
@ -81,7 +77,7 @@ async def call_alphafold2(
headers=headers, headers=headers,
status_url=status_url, status_url=status_url,
polling_interval=polling_interval, polling_interval=polling_interval,
timeout=timeout timeout=timeout,
) )
else: else:
logger.error("No request ID in response headers") logger.error("No request ID in response headers")
@ -93,16 +89,18 @@ async def call_alphafold2(
return None return None
except Exception as e: except Exception as e:
import traceback import traceback
logger.error(f"Error calling AlphaFold2 API: {e}") logger.error(f"Error calling AlphaFold2 API: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return None return None
async def _poll_job_status( async def _poll_job_status(
req_id: str, req_id: str,
headers: Dict[str, str], headers: Dict[str, str],
status_url: str, status_url: str,
polling_interval: int = 10, polling_interval: int = 10,
timeout: int = 60 timeout: int = 60,
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
""" """
Poll the status endpoint until the job completes. Poll the status endpoint until the job completes.
@ -121,18 +119,20 @@ async def _poll_job_status(
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get( async with session.get(
f"{status_url}/{req_id}", f"{status_url}/{req_id}", headers=headers, timeout=timeout
headers=headers,
timeout=timeout
) as response: ) as response:
if response.status == 200: if response.status == 200:
logger.info(f"AlphaFold2 job {req_id} completed") logger.info(f"AlphaFold2 job {req_id} completed")
return await response.json() return await response.json()
elif response.status == 202: elif response.status == 202:
logger.debug(f"AlphaFold2 job {req_id} still running, polling...") logger.debug(
f"AlphaFold2 job {req_id} still running, polling..."
)
await asyncio.sleep(polling_interval) await asyncio.sleep(polling_interval)
else: else:
logger.error(f"Error checking AlphaFold2 job status: {response.status}") logger.error(
f"Error checking AlphaFold2 job status: {response.status}"
)
text = await response.text() text = await response.text()
logger.error(f"Response: {text}") logger.error(f"Response: {text}")
return None return None

View file

@ -1,16 +1,16 @@
import os
import logging
import aiohttp
import json
import asyncio import asyncio
from typing import Dict, List, Any, Optional, Tuple import json
from pathlib import Path import logging
from typing import Any, Dict, List, Optional, Tuple
import aiohttp
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2-multimer" DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2-multimer"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status" DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
def _split_pdb_content(concatenated_pdb_str: str) -> List[str]: def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
""" """
Splits a string containing concatenated PDB file contents. Splits a string containing concatenated PDB file contents.
@ -35,7 +35,9 @@ def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
return [pdb for pdb in pdbs if pdb] return [pdb for pdb in pdbs if pdb]
def calculate_plddt_from_pdb_string(pdb_string: str) -> Tuple[float, List[float], Dict[str, List[float]]]: def calculate_plddt_from_pdb_string(
pdb_string: str,
) -> Tuple[float, List[float], Dict[str, List[float]]]:
total_plddt = 0.0 total_plddt = 0.0
ca_atom_count = 0 ca_atom_count = 0
plddt_scores_per_ca: List[float] = [] plddt_scores_per_ca: List[float] = []
@ -67,10 +69,11 @@ def calculate_plddt_from_pdb_string(pdb_string: str) -> Tuple[float, List[float]
average_plddt = total_plddt / ca_atom_count average_plddt = total_plddt / ca_atom_count
return average_plddt, plddt_scores_per_ca, plddt_scores_per_chain return average_plddt, plddt_scores_per_ca, plddt_scores_per_chain
async def _process_pdb_and_scores_from_api( async def _process_pdb_and_scores_from_api(
pdb_contents: List[str], pdb_contents: List[str],
job_id: str, job_id: str,
api_response_json: Optional[Dict[str, Any]] = None api_response_json: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
""" """
Processes a list of PDB strings received from the API. Processes a list of PDB strings received from the API.
@ -78,14 +81,19 @@ async def _process_pdb_and_scores_from_api(
- Does NOT save files to disk. - Does NOT save files to disk.
- Returns a dictionary containing a list of structures, each with its PDB content and scores. - Returns a dictionary containing a list of structures, each with its PDB content and scores.
""" """
results: Dict[str, Any] = { results: Dict[str, Any] = {"structures": []}
"structures": []
}
if not pdb_contents or not isinstance(pdb_contents, list) or not all(isinstance(s, str) for s in pdb_contents): if (
not pdb_contents
or not isinstance(pdb_contents, list)
or not all(isinstance(s, str) for s in pdb_contents)
):
logger.warning(f"No valid PDB content strings provided for job {job_id}.") logger.warning(f"No valid PDB content strings provided for job {job_id}.")
return {"success": False, "error": "No valid PDB content strings from API.", "structures": []} return {
"success": False,
"error": "No valid PDB content strings from API.",
"structures": [],
}
logger.info(f"Processing {len(pdb_contents)} PDB structure(s) for job {job_id}.") logger.info(f"Processing {len(pdb_contents)} PDB structure(s) for job {job_id}.")
@ -94,12 +102,11 @@ async def _process_pdb_and_scores_from_api(
logger.debug(f"Skipping empty PDB string at index {i} for job {job_id}.") logger.debug(f"Skipping empty PDB string at index {i} for job {job_id}.")
continue continue
structure_data: Dict[str, Any] = { structure_data: Dict[str, Any] = {"model_index": i, "pdb_content": pdb_str}
"model_index": i,
"pdb_content": pdb_str
}
avg_plddt, plddts_per_ca_residue, plddts_by_chain = calculate_plddt_from_pdb_string(pdb_str) avg_plddt, plddts_per_ca_residue, plddts_by_chain = (
calculate_plddt_from_pdb_string(pdb_str)
)
structure_data["average_plddt"] = avg_plddt structure_data["average_plddt"] = avg_plddt
structure_data["plddt_scores_per_ca_residue"] = plddts_per_ca_residue structure_data["plddt_scores_per_ca_residue"] = plddts_per_ca_residue
@ -116,13 +123,21 @@ async def _process_pdb_and_scores_from_api(
results["structures"].append(structure_data) results["structures"].append(structure_data)
if results["structures"]: if results["structures"]:
logger.info(f"Successfully processed and calculated pLDDTs for {len(results['structures'])} structures for job {job_id}.") logger.info(
f"Successfully processed and calculated pLDDTs for "
f"{len(results['structures'])} structures for job {job_id}."
)
else: else:
logger.warning(f"No structures were processed for job {job_id}.") logger.warning(f"No structures were processed for job {job_id}.")
return {"success": True, "message": "No PDB structures found in API response to process.", "structures": []} return {
"success": True,
"message": "No PDB structures found in API response to process.",
"structures": [],
}
return results return results
async def call_alphafold2_multimer( async def call_alphafold2_multimer(
sequences: List[str], sequences: List[str],
api_key: str, api_key: str,
@ -135,7 +150,7 @@ async def call_alphafold2_multimer(
url: str = DEFAULT_URL, url: str = DEFAULT_URL,
status_url: str = DEFAULT_STATUS_URL, status_url: str = DEFAULT_STATUS_URL,
polling_interval: int = 30, polling_interval: int = 30,
timeout: int = 3600 timeout: int = 3600,
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
""" """
Call the NVIDIA NIM AlphaFold2-Multimer API. Call the NVIDIA NIM AlphaFold2-Multimer API.
@ -155,7 +170,7 @@ async def call_alphafold2_multimer(
"e_value": e_value, "e_value": e_value,
"iterations": iterations, "iterations": iterations,
"databases": databases, "databases": databases,
"relax_prediction": relax_prediction "relax_prediction": relax_prediction,
} }
if selected_models is not None: if selected_models is not None:
data["selected_models"] = selected_models data["selected_models"] = selected_models
@ -165,10 +180,7 @@ async def call_alphafold2_multimer(
initial_post_timeout = min(timeout, 600) initial_post_timeout = min(timeout, 600)
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post( async with session.post(
url, url, json=data, headers=headers, timeout=initial_post_timeout
json=data,
headers=headers,
timeout=initial_post_timeout
) as response: ) as response:
if response.status == 200: if response.status == 200:
logger.info("AlphaFold2-Multimer job completed synchronously.") logger.info("AlphaFold2-Multimer job completed synchronously.")
@ -177,130 +189,239 @@ async def call_alphafold2_multimer(
if "application/json" in content_type: if "application/json" in content_type:
api_response_json_payload = await response.json() api_response_json_payload = await response.json()
if not isinstance(api_response_json_payload, list): if not isinstance(api_response_json_payload, list):
if isinstance(api_response_json_payload, dict) and "error" in api_response_json_payload: if (
logger.error(f"Sync API call returned error: {api_response_json_payload['error']}") isinstance(api_response_json_payload, dict)
return {"success": False, "error": api_response_json_payload['error'], "detail": api_response_json_payload.get("detail","")} and "error" in api_response_json_payload
return {"success": False, "error": "Sync JSON response not a list of PDBs as expected."} ):
logger.error(
f"Sync API call returned error: "
f"{api_response_json_payload['error']}"
)
return {
"success": False,
"error": api_response_json_payload["error"],
"detail": api_response_json_payload.get(
"detail", ""
),
}
return {
"success": False,
"error": "Sync JSON response not a list of PDBs as expected.",
}
req_id_sync = response.headers.get("nvcf-reqid", "sync_job") req_id_sync = response.headers.get("nvcf-reqid", "sync_job")
return await _process_pdb_and_scores_from_api( return await _process_pdb_and_scores_from_api(
pdb_contents=api_response_json_payload, pdb_contents=api_response_json_payload,
job_id=req_id_sync, job_id=req_id_sync,
api_response_json=None api_response_json=None,
) )
else: else:
err_text = await response.text() err_text = await response.text()
logger.error(f"Sync response unexpected content type: {content_type}. Response: {err_text[:500]}") logger.error(
return {"success": False, "error": f"Sync response unexpected content type: {content_type}", "detail": err_text} f"Sync response unexpected content type: {content_type}. "
f"Response: {err_text[:500]}"
)
return {
"success": False,
"error": f"Sync response unexpected content type: {content_type}",
"detail": err_text,
}
elif response.status == 202: elif response.status == 202:
req_id = response.headers.get("nvcf-reqid") req_id = response.headers.get("nvcf-reqid")
if req_id: if req_id:
logger.info(f"AlphaFold2-Multimer job submitted, request ID: {req_id}") logger.info(
f"AlphaFold2-Multimer job submitted, request ID: {req_id}"
)
return await _poll_job_status( return await _poll_job_status(
req_id=req_id, req_id=req_id,
headers=headers, headers=headers,
status_url=status_url, status_url=status_url,
polling_interval=polling_interval, polling_interval=polling_interval,
overall_timeout=timeout overall_timeout=timeout,
) )
else: else:
logger.error("No request ID in 202 response headers") logger.error("No request ID in 202 response headers")
return {"success": False, "error": "No request ID in 202 response headers"} return {
"success": False,
"error": "No request ID in 202 response headers",
}
else: else:
logger.error(f"Error calling AlphaFold2-Multimer API (POST): {response.status}") logger.error(
f"Error calling AlphaFold2-Multimer API (POST): {response.status}"
)
text = await response.text() text = await response.text()
logger.error(f"Response: {text}") logger.error(f"Response: {text}")
return {"success": False, "error": f"Error calling API: {response.status}", "detail": text} return {
"success": False,
"error": f"Error calling API: {response.status}",
"detail": text,
}
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.error(f"Timeout during AlphaFold2-Multimer API (initial POST).") logger.error("Timeout during AlphaFold2-Multimer API (initial POST).")
return {"success": False, "error": "Timeout during initial API request"} return {"success": False, "error": "Timeout during initial API request"}
except Exception as e: except Exception as e:
logger.error(f"Exception during AlphaFold2-Multimer API call (initial POST): {e}", exc_info=True) logger.error(
f"Exception during AlphaFold2-Multimer API call (initial POST): {e}",
exc_info=True,
)
return {"success": False, "error": f"Exception during API call: {str(e)}"} return {"success": False, "error": f"Exception during API call: {str(e)}"}
async def _poll_job_status( async def _poll_job_status(
req_id: str, req_id: str,
headers: Dict[str, str], headers: Dict[str, str],
status_url: str, status_url: str,
polling_interval: int = 30, polling_interval: int = 30,
overall_timeout: int = 3600 overall_timeout: int = 3600,
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
start_time = asyncio.get_event_loop().time() start_time = asyncio.get_event_loop().time()
per_status_request_timeout = 600 per_status_request_timeout = 600
logger.info(f"Polling job {req_id}. Individual status check timeout: {per_status_request_timeout}s, Polling interval: {polling_interval}s, Overall timeout: {overall_timeout}s") logger.info(
f"Polling job {req_id}. Individual status check timeout: "
f"{per_status_request_timeout}s, Polling interval: {polling_interval}s, "
f"Overall timeout: {overall_timeout}s"
)
while True: while True:
current_loop_time = asyncio.get_event_loop().time() current_loop_time = asyncio.get_event_loop().time()
elapsed_time = current_loop_time - start_time elapsed_time = current_loop_time - start_time
if elapsed_time >= overall_timeout: if elapsed_time >= overall_timeout:
logger.error(f"Overall polling timeout of {overall_timeout}s exceeded for job {req_id}.") logger.error(
f"Overall polling timeout of {overall_timeout}s exceeded for "
f"job {req_id}."
)
return {"success": False, "error": "Overall polling timeout exceeded."} return {"success": False, "error": "Overall polling timeout exceeded."}
remaining_time_for_overall_timeout = overall_timeout - elapsed_time remaining_time_for_overall_timeout = overall_timeout - elapsed_time
current_status_check_timeout = min(per_status_request_timeout, remaining_time_for_overall_timeout) current_status_check_timeout = min(
per_status_request_timeout, remaining_time_for_overall_timeout
)
if current_status_check_timeout <= 0: if current_status_check_timeout <= 0:
logger.error(f"Not enough time left for another status check for job {req_id} within overall_timeout.") logger.error(
return {"success": False, "error": "Not enough time for status check within overall timeout."} f"Not enough time left for another status check for job {req_id} "
f"within overall_timeout."
)
return {
"success": False,
"error": "Not enough time for status check within overall timeout.",
}
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
logger.debug(f"Checking status for {req_id} with timeout {current_status_check_timeout}s.") logger.debug(
f"Checking status for {req_id} with timeout "
f"{current_status_check_timeout}s."
)
async with session.get( async with session.get(
f"{status_url}/{req_id}", f"{status_url}/{req_id}",
headers=headers, headers=headers,
timeout=current_status_check_timeout timeout=current_status_check_timeout,
) as response: ) as response:
if response.status == 200: if response.status == 200:
logger.info(f"AlphaFold2-Multimer job {req_id} completed (status 200).") logger.info(
if response.content_type == 'application/json': f"AlphaFold2-Multimer job {req_id} completed (status 200)."
)
if response.content_type == "application/json":
try: try:
api_response_json_payload = await response.json() api_response_json_payload = await response.json()
if not isinstance(api_response_json_payload, list): if not isinstance(api_response_json_payload, list):
if isinstance(api_response_json_payload, dict) and "error" in api_response_json_payload: if (
logger.error(f"Job {req_id}: API returned error: {api_response_json_payload['error']}") isinstance(api_response_json_payload, dict)
return {"success": False, "error": api_response_json_payload['error'], "detail": api_response_json_payload.get("detail","")} and "error" in api_response_json_payload
logger.error(f"Job {req_id}: Expected API response to be a list of PDB strings, got {type(api_response_json_payload)}.") ):
return {"success": False, "error": "API response was not a list of PDB strings."} logger.error(
f"Job {req_id}: API returned error: "
f"{api_response_json_payload['error']}"
)
return {
"success": False,
"error": api_response_json_payload["error"],
"detail": api_response_json_payload.get(
"detail", ""
),
}
logger.error(
f"Job {req_id}: Expected API response to be a list of PDB strings, "
f"got {type(api_response_json_payload)}."
)
return {
"success": False,
"error": "API response was not a list of PDB strings.",
}
return await _process_pdb_and_scores_from_api( return await _process_pdb_and_scores_from_api(
pdb_contents=api_response_json_payload, pdb_contents=api_response_json_payload,
job_id=req_id, job_id=req_id,
api_response_json=None api_response_json=None,
) )
except json.JSONDecodeError: except json.JSONDecodeError:
logger.error(f"Job {req_id}: Failed to decode JSON response from API.", exc_info=True) logger.error(
f"Job {req_id}: Failed to decode JSON response from API.",
exc_info=True,
)
raw_text = await response.text() raw_text = await response.text()
return {"success": False, "error": "Failed to decode JSON response.", "detail": raw_text[:500]} return {
"success": False,
"error": "Failed to decode JSON response.",
"detail": raw_text[:500],
}
else: else:
raw_text = await response.text() raw_text = await response.text()
logger.error(f"Job {req_id}: Unexpected content type {response.content_type}. Expected application/json. Response: {raw_text[:500]}") logger.error(
return {"success": False, "error": f"Unexpected content type: {response.content_type}", "detail": raw_text} f"Job {req_id}: Unexpected content type {response.content_type}. "
f"Expected application/json. Response: {raw_text[:500]}"
)
return {
"success": False,
"error": f"Unexpected content type: {response.content_type}",
"detail": raw_text,
}
elif response.status == 202: elif response.status == 202:
try: try:
job_status_json = await response.json() job_status_json = await response.json()
percent_complete = job_status_json.get('percentComplete', 'N/A') percent_complete = job_status_json.get(
status_message = job_status_json.get('status', 'running') "percentComplete", "N/A"
)
status_message = job_status_json.get("status", "running")
logger.debug( logger.debug(
f"Job {req_id} status: {status_message} ({percent_complete}% complete). Polling again in {polling_interval}s." f"Job {req_id} status: {status_message} ({percent_complete}% complete). "
f"Polling again in {polling_interval}s."
) )
except (aiohttp.ContentTypeError, json.JSONDecodeError): except (aiohttp.ContentTypeError, json.JSONDecodeError):
logger.debug( logger.debug(
f"Job {req_id} still running (202 status, non-JSON/malformed JSON body). Polling again in {polling_interval}s." f"Job {req_id} still running (202 status, non-JSON/malformed JSON body). "
f"Polling again in {polling_interval}s."
) )
await asyncio.sleep(polling_interval) await asyncio.sleep(polling_interval)
else: else:
text = await response.text() text = await response.text()
logger.error(f"Error checking AlphaFold2-Multimer job status {req_id}: HTTP {response.status} - {text}") logger.error(
return {"success": False, "error": f"Status check failed with HTTP {response.status}", "detail": text} f"Error checking AlphaFold2-Multimer job status {req_id}: "
f"HTTP {response.status} - {text}"
)
return {
"success": False,
"error": f"Status check failed with HTTP {response.status}",
"detail": text,
}
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.warning(f"Client-side timeout ({current_status_check_timeout}s) during status check for job {req_id}. Retrying poll after {polling_interval}s sleep.") logger.warning(
f"Client-side timeout ({current_status_check_timeout}s) during status check for "
f"job {req_id}. Retrying poll after {polling_interval}s sleep."
)
await asyncio.sleep(polling_interval) await asyncio.sleep(polling_interval)
except aiohttp.ClientError as e: except aiohttp.ClientError as e:
logger.error(f"Client error polling job status for {req_id}: {e}. Retrying poll after {polling_interval}s.", exc_info=True) logger.error(
f"Client error polling job status for {req_id}: {e}. "
f"Retrying poll after {polling_interval}s.",
exc_info=True,
)
await asyncio.sleep(polling_interval) await asyncio.sleep(polling_interval)
except Exception as e: except Exception as e:
logger.error(f"Unexpected error polling job status {req_id}: {e}", exc_info=True) logger.error(
f"Unexpected error polling job status {req_id}: {e}", exc_info=True
)
return {"success": False, "error": f"Unexpected polling error: {str(e)}"} return {"success": False, "error": f"Unexpected polling error: {str(e)}"}

View file

@ -1,16 +1,15 @@
import os
import logging
import aiohttp
import json
import asyncio import asyncio
from typing import Dict, List, Any, Optional, Union import logging
from pathlib import Path from typing import Any, Dict, List, Optional
import aiohttp
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/ipd/proteinmpnn/predict" DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/ipd/proteinmpnn/predict"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status" DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
async def call_proteinmpnn( async def call_proteinmpnn(
input_pdb: str, input_pdb: str,
api_key: str, api_key: str,
@ -20,7 +19,7 @@ async def call_proteinmpnn(
url: str = DEFAULT_URL, url: str = DEFAULT_URL,
status_url: str = DEFAULT_STATUS_URL, status_url: str = DEFAULT_STATUS_URL,
polling_interval: int = 10, polling_interval: int = 10,
timeout: int = 60 timeout: int = 60,
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
""" """
Call the NVIDIA NIM ProteinMPNN API. Call the NVIDIA NIM ProteinMPNN API.
@ -49,16 +48,13 @@ async def call_proteinmpnn(
"input_pdb": input_pdb, "input_pdb": input_pdb,
"ca_only": ca_only, "ca_only": ca_only,
"use_soluble_model": use_soluble_model, "use_soluble_model": use_soluble_model,
"sampling_temp": sampling_temp "sampling_temp": sampling_temp,
} }
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post( async with session.post(
url, url, json=data, headers=headers, timeout=timeout
json=data,
headers=headers,
timeout=timeout
) as response: ) as response:
if response.status == 200: if response.status == 200:
return await response.json() return await response.json()
@ -71,7 +67,7 @@ async def call_proteinmpnn(
headers=headers, headers=headers,
status_url=status_url, status_url=status_url,
polling_interval=polling_interval, polling_interval=polling_interval,
timeout=timeout timeout=timeout,
) )
else: else:
logger.error("No request ID in response headers") logger.error("No request ID in response headers")
@ -85,12 +81,13 @@ async def call_proteinmpnn(
logger.error(f"Error calling ProteinMPNN API: {e}") logger.error(f"Error calling ProteinMPNN API: {e}")
return None return None
async def _poll_job_status( async def _poll_job_status(
req_id: str, req_id: str,
headers: Dict[str, str], headers: Dict[str, str],
status_url: str, status_url: str,
polling_interval: int = 10, polling_interval: int = 10,
timeout: int = 60 timeout: int = 60,
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
""" """
Poll the status endpoint until the job completes. Poll the status endpoint until the job completes.
@ -109,18 +106,20 @@ async def _poll_job_status(
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get( async with session.get(
f"{status_url}/{req_id}", f"{status_url}/{req_id}", headers=headers, timeout=timeout
headers=headers,
timeout=timeout
) as response: ) as response:
if response.status == 200: if response.status == 200:
logger.info(f"ProteinMPNN job {req_id} completed") logger.info(f"ProteinMPNN job {req_id} completed")
return await response.json() return await response.json()
elif response.status == 202: elif response.status == 202:
logger.debug(f"ProteinMPNN job {req_id} still running, polling...") logger.debug(
f"ProteinMPNN job {req_id} still running, polling..."
)
await asyncio.sleep(polling_interval) await asyncio.sleep(polling_interval)
else: else:
logger.error(f"Error checking ProteinMPNN job status: {response.status}") logger.error(
f"Error checking ProteinMPNN job status: {response.status}"
)
text = await response.text() text = await response.text()
logger.error(f"Response: {text}") logger.error(f"Response: {text}")
return None return None

View file

@ -1,16 +1,15 @@
import os
import logging
import aiohttp
import json
import asyncio import asyncio
from typing import Dict, List, Any, Optional, Union import logging
from pathlib import Path from typing import Any, Dict, List, Optional
import aiohttp
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/ipd/rfdiffusion/generate" DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/ipd/rfdiffusion/generate"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status" DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
async def call_rfdiffusion( async def call_rfdiffusion(
input_pdb: str, input_pdb: str,
api_key: str, api_key: str,
@ -20,7 +19,7 @@ async def call_rfdiffusion(
url: str = DEFAULT_URL, url: str = DEFAULT_URL,
status_url: str = DEFAULT_STATUS_URL, status_url: str = DEFAULT_STATUS_URL,
polling_interval: int = 10, polling_interval: int = 10,
timeout: int = 60 timeout: int = 60,
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
""" """
Call the NVIDIA NIM RFDiffusion API. Call the NVIDIA NIM RFDiffusion API.
@ -45,10 +44,7 @@ async def call_rfdiffusion(
"NVCF-POLL-SECONDS": "300", "NVCF-POLL-SECONDS": "300",
} }
data = { data = {"input_pdb": input_pdb, "diffusion_steps": diffusion_steps}
"input_pdb": input_pdb,
"diffusion_steps": diffusion_steps
}
if contigs: if contigs:
data["contigs"] = contigs data["contigs"] = contigs
@ -58,10 +54,7 @@ async def call_rfdiffusion(
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post( async with session.post(
url, url, json=data, headers=headers, timeout=timeout
json=data,
headers=headers,
timeout=timeout
) as response: ) as response:
if response.status == 200: if response.status == 200:
return await response.json() return await response.json()
@ -74,7 +67,7 @@ async def call_rfdiffusion(
headers=headers, headers=headers,
status_url=status_url, status_url=status_url,
polling_interval=polling_interval, polling_interval=polling_interval,
timeout=timeout timeout=timeout,
) )
else: else:
logger.error("No request ID in response headers") logger.error("No request ID in response headers")
@ -88,12 +81,13 @@ async def call_rfdiffusion(
logger.error(f"Error calling RFDiffusion API: {e}") logger.error(f"Error calling RFDiffusion API: {e}")
return None return None
async def _poll_job_status( async def _poll_job_status(
req_id: str, req_id: str,
headers: Dict[str, str], headers: Dict[str, str],
status_url: str, status_url: str,
polling_interval: int = 10, polling_interval: int = 10,
timeout: int = 60 timeout: int = 60,
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
""" """
Poll the status endpoint until the job completes. Poll the status endpoint until the job completes.
@ -112,18 +106,20 @@ async def _poll_job_status(
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get( async with session.get(
f"{status_url}/{req_id}", f"{status_url}/{req_id}", headers=headers, timeout=timeout
headers=headers,
timeout=timeout
) as response: ) as response:
if response.status == 200: if response.status == 200:
logger.info(f"RFDiffusion job {req_id} completed") logger.info(f"RFDiffusion job {req_id} completed")
return await response.json() return await response.json()
elif response.status == 202: elif response.status == 202:
logger.debug(f"RFDiffusion job {req_id} still running, polling...") logger.debug(
f"RFDiffusion job {req_id} still running, polling..."
)
await asyncio.sleep(polling_interval) await asyncio.sleep(polling_interval)
else: else:
logger.error(f"Error checking RFDiffusion job status: {response.status}") logger.error(
f"Error checking RFDiffusion job status: {response.status}"
)
text = await response.text() text = await response.text()
logger.error(f"Response: {text}") logger.error(f"Response: {text}")
return None return None

View file

@ -1,25 +1,25 @@
import logging import logging
from typing import Dict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SYSTEM_PROMPT = """You are a specialized AI system for de novo protein design via a staged simulation loop. Your objective is to generate binder sequences that are structurally and functionally optimized to bind a given target protein. SYSTEM_PROMPT = (
"You are a specialized AI system for de novo protein design via a staged simulation loop. "
"Your objective is to generate binder sequences that are structurally and functionally "
"optimized to bind a given target protein.\n\n"
"You will be guided through a multi-step pipeline:\n\n"
"1. Structure prediction (AlphaFold)\n"
"2. Binder backbone generation (RFdiffusion)\n"
"3. Sequence design (ProteinMPNN)\n"
"4. Structure evaluation (AlphaFold-Multimer)\n"
"5. Feedback loop\n\n"
"You must always:\n"
"- Respect the required file format for each tool (e.g., FASTA, PDB).\n"
"- Structure your outputs cleanly so they can be parsed and executed programmatically.\n"
"- Be explicit in all configuration steps (e.g., contigs, hotspots).\n"
"- Minimize ambiguity or verbosity; prefer concise and functional outputs.\n"
"- Reason step-by-step when appropriate."
)
You will be guided through a multi-step pipeline:
1. Structure prediction (AlphaFold)
2. Binder backbone generation (RFdiffusion)
3. Sequence design (ProteinMPNN)
4. Structure evaluation (AlphaFold-Multimer)
5. Feedback loop
You must always:
- Respect the required file format for each tool (e.g., FASTA, PDB).
- Structure your outputs cleanly so they can be parsed and executed programmatically.
- Be explicit in all configuration steps (e.g., contigs, hotspots).
- Minimize ambiguity or verbosity; prefer concise and functional outputs.
- Reason step-by-step when appropriate.
"""
def construct_user_prompt(state: dict) -> str: def construct_user_prompt(state: dict) -> str:
""" """
@ -42,19 +42,20 @@ def construct_user_prompt(state: dict) -> str:
"You must provide the 'sequence' argument." "You must provide the 'sequence' argument."
) )
elif internal_step == 1: elif internal_step == 1:
target_pdb_preview = state.get("target_pdb_preview", "PDB preview not available.")
chain_details = state.get("target_chain_details", {}) chain_details = state.get("target_chain_details", {})
if chain_details: if chain_details:
chain_info_parts = [] chain_info_parts = []
for chain_id, details in chain_details.items(): for chain_id, details in chain_details.items():
min_r = details.get('min_residue', 'N/A') min_r = details.get("min_residue", "N/A")
max_r = details.get('max_residue', 'N/A') max_r = details.get("max_residue", "N/A")
l = details.get('length', 'N/A') length = details.get("length", "N/A")
chain_info_parts.append(f"Chain {chain_id} (Residues: {min_r}-{max_r}, Length: {l} amino acids)") chain_info_parts.append(
f"Chain {chain_id} (Residues: {min_r}-{max_r}, Length: {length} amino acids)"
)
chain_info_str = "\n- ".join(chain_info_parts) chain_info_str = "\n- ".join(chain_info_parts)
if chain_info_str: if chain_info_str:
chain_info_str = "- " + chain_info_str chain_info_str = "- " + chain_info_str
else: else:
chain_info_str = "Chain information not available or PDB not yet processed." chain_info_str = "Chain information not available or PDB not yet processed."
@ -62,19 +63,26 @@ def construct_user_prompt(state: dict) -> str:
f"The 3D structure of the target protein has been predicted.\n" f"The 3D structure of the target protein has been predicted.\n"
f"Target Protein Chain Details:\n{chain_info_str}\n\n" f"Target Protein Chain Details:\n{chain_info_str}\n\n"
"Your task is to design a binder backbone using the 'design_binder_backbone_rfdiffusion' tool. " "Your task is to design a binder backbone using the 'design_binder_backbone_rfdiffusion' tool. "
"You MUST specify 'contigs' for this tool. The 'contigs' string defines segments from the target PDB and segments for the new binder. " "You MUST specify 'contigs' for this tool. The 'contigs' string defines segments from the target PDB "
"and segments for the new binder. "
"Examples:\n" "Examples:\n"
" - To use residues 10 through 100 of target chain A, and then diffuse a 60-residue binder: 'A10-100/0 60'\n" " - To use residues 10 through 100 of target chain A, and then diffuse a 60-residue binder: "
" - To use chain B from residue 5 to 50, then diffuse a 30-residue binder, then use chain B from residue 60 to 100: 'B5-50/0 30 B60-100'\n" "'A10-100/0 60'\n"
"You MUST use the chain IDs and residue ranges exactly as provided in the 'Target Protein Chain Details' above. " " - To use chain B from residue 5 to 50, then diffuse a 30-residue binder, then use chain B "
"from residue 60 to 100: 'B5-50/0 30 B60-100'\n"
"You MUST use the chain IDs and residue ranges exactly as provided in the "
"'Target Protein Chain Details' above. "
"Do not invent chains or residue numbers outside these specified ranges for the target segments. " "Do not invent chains or residue numbers outside these specified ranges for the target segments. "
"For binder segments (e.g., '/0 60'), specify the desired length (e.g., 60).\n" "For binder segments (e.g., '/0 60'), specify the desired length (e.g., 60).\n"
"Optionally, provide 'hotspot_residues' (e.g., ['A50', 'A52']), ensuring these residues exist on the target as per the details above." "Optionally, provide 'hotspot_residues' (e.g., ['A50', 'A52']), ensuring these residues exist "
"on the target as per the details above."
) )
elif internal_step == 2: elif internal_step == 2:
binder_pdb_content = state.get("binder_backbone_pdb_content") binder_pdb_content = state.get("binder_backbone_pdb_content")
binder_pdb_preview = state.get("binder_pdb_preview", "Binder PDB preview not available.") binder_pdb_preview = state.get(
"binder_pdb_preview", "Binder PDB preview not available."
)
binder_chain_info_str = "Binder chain information not available." binder_chain_info_str = "Binder chain information not available."
if binder_pdb_content: if binder_pdb_content:
@ -83,10 +91,12 @@ def construct_user_prompt(state: dict) -> str:
if binder_chain_details: if binder_chain_details:
chain_info_parts = [] chain_info_parts = []
for cID, d_details in binder_chain_details.items(): for cID, d_details in binder_chain_details.items():
min_r = d_details.get('min_residue', 'N/A') min_r = d_details.get("min_residue", "N/A")
max_r = d_details.get('max_residue', 'N/A') max_r = d_details.get("max_residue", "N/A")
l = d_details.get('length', 'N/A') length = d_details.get("length", "N/A")
chain_info_parts.append(f"Chain {cID} (Residues: {min_r}-{max_r}, Length: {l} amino acids)") chain_info_parts.append(
f"Chain {cID} (Residues: {min_r}-{max_r}, Length: {length} amino acids)"
)
binder_chain_info_str = "\n- ".join(chain_info_parts) binder_chain_info_str = "\n- ".join(chain_info_parts)
if binder_chain_info_str: if binder_chain_info_str:
binder_chain_info_str = "- " + binder_chain_info_str binder_chain_info_str = "- " + binder_chain_info_str
@ -99,7 +109,8 @@ def construct_user_prompt(state: dict) -> str:
user_prompt_str = ( user_prompt_str = (
f"A binder backbone has been generated. Binder PDB preview:\n{binder_pdb_preview}\n" f"A binder backbone has been generated. Binder PDB preview:\n{binder_pdb_preview}\n"
f"Binder chain information:\n{binder_chain_info_str}.\n" f"Binder chain information:\n{binder_chain_info_str}.\n"
"Now, design an optimal amino acid sequence for this binder backbone using the 'design_binder_sequence_proteinmpnn' tool. " "Now, design an optimal amino acid sequence for this binder backbone using the "
"'design_binder_sequence_proteinmpnn' tool. "
"You can optionally specify 'sampling_temp' (e.g., [0.1, 0.2])." "You can optionally specify 'sampling_temp' (e.g., [0.1, 0.2])."
) )
elif internal_step == 3: elif internal_step == 3:
@ -110,27 +121,39 @@ def construct_user_prompt(state: dict) -> str:
if len(designed_binder_seq_data) == 1: if len(designed_binder_seq_data) == 1:
binder_display_str = designed_binder_seq_data[0] binder_display_str = designed_binder_seq_data[0]
else: else:
binder_display_str = f"{len(designed_binder_seq_data)} chains: " + \ binder_display_str = (
", ".join([f"Chain {i+1} ({len(s)} aa): {s[:20]}..." f"{len(designed_binder_seq_data)} chains: "
for i, s in enumerate(designed_binder_seq_data)]) + ", ".join(
[
f"Chain {i+1} ({len(s)} aa): {s[:20]}..."
for i, s in enumerate(designed_binder_seq_data)
]
)
)
elif isinstance(designed_binder_seq_data, str): elif isinstance(designed_binder_seq_data, str):
binder_display_str = designed_binder_seq_data binder_display_str = designed_binder_seq_data
user_prompt_str = ( user_prompt_str = (
f"A binder has been designed. Designed binder sequence(s): {binder_display_str}. " f"A binder has been designed. Designed binder sequence(s): {binder_display_str}. "
f"The original target sequence was: {target_sequence[:60]}...\n" f"The original target sequence was: {target_sequence[:60]}...\n"
"Finally, evaluate the binding complex of the original target protein and ALL chains of this designed binder using the " "Finally, evaluate the binding complex of the original target protein and ALL chains of this "
"'evaluate_binder_complex_alphafold2_multimer' tool. " "designed binder using the 'evaluate_binder_complex_alphafold2_multimer' tool. "
"You can optionally specify 'relax_prediction' (default is True)." "You can optionally specify 'relax_prediction' (default is True)."
) )
else: else:
user_prompt_str = "The protein design workflow is complete. No further actions required by you for this item. If successful, the key metric was the pLDDT of the complex." user_prompt_str = (
"The protein design workflow is complete. No further actions required by you for this item. "
"If successful, the key metric was the pLDDT of the complex."
)
if state.get("retry_count_this_internal_step", 0) > 0 and internal_step < 4: if state.get("retry_count_this_internal_step", 0) > 0 and internal_step < 4:
retry_prefix = "Your previous attempt at this step was not successful. " retry_prefix = "Your previous attempt at this step was not successful. "
if state.get("previous_tool_error_message"): if state.get("previous_tool_error_message"):
retry_prefix += f"Details: {state['previous_tool_error_message']}. " retry_prefix += f"Details: {state['previous_tool_error_message']}. "
retry_prefix += "Please review the requirements and PDB details carefully and try again to correctly use the expected tool.\n\n" retry_prefix += (
"Please review the requirements and PDB details carefully and try again to correctly use "
"the expected tool.\n\n"
)
user_prompt_str = retry_prefix + user_prompt_str user_prompt_str = retry_prefix + user_prompt_str
return user_prompt_str return user_prompt_str

View file

@ -0,0 +1,92 @@
# OpenAI function-calling schema for the first workflow step (internal step 0):
# predict the target protein's 3D structure with AlphaFold2.
# Executed by ToolExecutor._run_nim_alphafold2, which logs a warning and uses
# the canonical target sequence from workflow state if the model supplies a
# different one.
PREDICT_TARGET_STRUCTURE_TOOL = {
    "type": "function",
    "function": {
        "name": "predict_target_structure_alphafold2",
        "description": "Predicts the 3D structure of the target protein sequence using AlphaFold2.",
        "parameters": {
            "type": "object",
            "properties": {
                "sequence": {
                    "type": "string",
                    "description": "Amino acid sequence of the target protein.",
                },
            },
            # The sequence is the only — and mandatory — argument.
            "required": ["sequence"],
        },
    },
}
# Schema for internal step 1: generate a binder backbone with RFDiffusion,
# conditioned on the predicted target structure. Executed by
# ToolExecutor._run_nim_rfdiffusion, which validates 'contigs' and
# 'hotspot_residues' against the target PDB's chain details before calling
# the API.
DESIGN_BINDER_BACKBONE_TOOL = {
    "type": "function",
    "function": {
        "name": "design_binder_backbone_rfdiffusion",
        "description": (
            "Generates a novel protein binder backbone using RFDiffusion, "
            "conditioned on the target protein structure."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                # Contig string mixing target segments (e.g. 'A1-100') and
                # binder-length segments (e.g. '/0 50-70'); format is checked
                # by ToolExecutor._validate_rfd_contigs.
                "contigs": {
                    "type": "string",
                    "description": "RFDiffusion contigs (e.g., 'A1-100/0 50-70').",
                },
                # Optional chain+residue identifiers on the target; checked by
                # ToolExecutor._validate_rfd_hotspots.
                "hotspot_residues": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Optional hotspot residues (e.g., ['A50', 'A52']).",
                },
            },
            "required": ["contigs"],
        },
    },
}
# Schema for internal step 2: design an amino-acid sequence for the generated
# binder backbone with ProteinMPNN. Executed by
# ToolExecutor._run_nim_proteinmpnn. All arguments are optional.
DESIGN_BINDER_SEQUENCE_TOOL = {
    "type": "function",
    "function": {
        "name": "design_binder_sequence_proteinmpnn",
        "description": "Designs an amino acid sequence for the generated binder backbone.",
        "parameters": {
            "type": "object",
            "properties": {
                # Optional list of sampling temperatures forwarded to
                # ProteinMPNN; defaults to [0.1] per the description.
                "sampling_temp": {
                    "type": "array",
                    "items": {"type": "number"},
                    "description": (
                        "Sampling temperatures (e.g., [0.1, 0.2]). Default [0.1]."
                    ),
                }
            },
            "required": [],
        },
    },
}
# Schema for internal step 3 (final step): evaluate the target/binder complex
# with AlphaFold2-Multimer. Per the environment's reward description, the
# pLDDT of the predicted complex is the key quality metric. All arguments are
# optional.
EVALUATE_COMPLEX_TOOL = {
    "type": "function",
    "function": {
        "name": "evaluate_binder_complex_alphafold2_multimer",
        "description": (
            "Predicts the complex structure of target and designed binder, "
            "providing quality scores."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                # Whether to run the relaxation step on the predicted complex;
                # defaults to True.
                "relax_prediction": {
                    "type": "boolean",
                    "description": "Whether to relax the prediction. Default True.",
                }
            },
            "required": [],
        },
    },
}
# All tool schemas exposed to the model, ordered to match the workflow's
# internal steps 0-3 (structure prediction -> backbone -> sequence ->
# complex evaluation).
ALL_TOOLS_LIST = [
    PREDICT_TARGET_STRUCTURE_TOOL,
    DESIGN_BINDER_BACKBONE_TOOL,
    DESIGN_BINDER_SEQUENCE_TOOL,
    EVALUATE_COMPLEX_TOOL,
]

View file

@ -1,19 +1,26 @@
import logging import logging
import json
import re import re
from typing import Dict, Any, List, Tuple, Optional, Union
from pathlib import Path from pathlib import Path
from environments.hack0.protein_design_env.models.alphafold2 import call_alphafold2 from typing import Dict, List, Optional, Tuple
from environments.hack0.protein_design_env.models.rfdiffusion import call_rfdiffusion
from environments.hack0.protein_design_env.models.proteinmpnn import call_proteinmpnn from .models.alphafold2 import call_alphafold2
from environments.hack0.protein_design_env.models.alphafold2_multimer import call_alphafold2_multimer from .models.alphafold2_multimer import call_alphafold2_multimer
from environments.hack0.protein_design_env.utils.pdb_utils import get_pdb_chain_details from .models.proteinmpnn import call_proteinmpnn
from .models.rfdiffusion import call_rfdiffusion
from .utils.pdb_utils import get_pdb_chain_details
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ToolExecutor: class ToolExecutor:
def __init__(self, nim_api_key: str, api_timeout: int, polling_interval: int, def __init__(
output_dir: Path, debug_protein_design_calls: bool): self,
nim_api_key: str,
api_timeout: int,
polling_interval: int,
output_dir: Path,
debug_protein_design_calls: bool,
):
self.nim_api_key = nim_api_key self.nim_api_key = nim_api_key
self.api_timeout = api_timeout self.api_timeout = api_timeout
self.polling_interval = polling_interval self.polling_interval = polling_interval
@ -21,18 +28,23 @@ class ToolExecutor:
self.debug_protein_design_calls = debug_protein_design_calls self.debug_protein_design_calls = debug_protein_design_calls
self._debug_af2m_call_count = 0 self._debug_af2m_call_count = 0
def _validate_rfd_contigs(self, contigs_str: str, target_chain_details: Dict[str, Dict[str, int]]) -> Optional[str]: def _validate_rfd_contigs(
self, contigs_str: str, target_chain_details: Dict[str, Dict[str, int]]
) -> Optional[str]:
""" """
Validates the RFDiffusion contigs string against target PDB chain details. Validates the RFDiffusion contigs string against target PDB chain details.
Returns None if valid, or an error message string if invalid. Returns None if valid, or an error message string if invalid.
""" """
if not contigs_str: return "Contigs string is empty." if not contigs_str:
return "Contigs string is empty."
target_segment_pattern = re.compile(r"([A-Za-z0-9])(\d+)-(\d+)") target_segment_pattern = re.compile(r"([A-Za-z0-9])(\d+)-(\d+)")
active_contig_parts = contigs_str.split('/') # Split by binder definition markers active_contig_parts = contigs_str.split(
"/"
) # Split by binder definition markers
for part in active_contig_parts: for part in active_contig_parts:
chain_segments_in_part = part.strip().split(' ') chain_segments_in_part = part.strip().split(" ")
for segment_text in chain_segments_in_part: for segment_text in chain_segments_in_part:
segment_text = segment_text.strip() segment_text = segment_text.strip()
if not segment_text or segment_text.isdigit(): if not segment_text or segment_text.isdigit():
@ -45,42 +57,61 @@ class ToolExecutor:
seg_end = int(seg_end_str) seg_end = int(seg_end_str)
if seg_chain_id not in target_chain_details: if seg_chain_id not in target_chain_details:
return f"Chain '{seg_chain_id}' in contig segment '{segment_text}' not in target. Valid: {list(target_chain_details.keys())}." return (
f"Chain '{seg_chain_id}' in contig segment '{segment_text}' not in target. "
f"Valid: {list(target_chain_details.keys())}."
)
chain_min = target_chain_details[seg_chain_id]["min_residue"] chain_min = target_chain_details[seg_chain_id]["min_residue"]
chain_max = target_chain_details[seg_chain_id]["max_residue"] chain_max = target_chain_details[seg_chain_id]["max_residue"]
if not (chain_min <= seg_start <= chain_max and chain_min <= seg_end <= chain_max and seg_start <= seg_end): if not (
return (f"Residue range {seg_start}-{seg_end} for chain '{seg_chain_id}' in '{segment_text}' " chain_min <= seg_start <= chain_max
f"is invalid/out of bounds. Chain '{seg_chain_id}' actual range: {chain_min}-{chain_max}.") and chain_min <= seg_end <= chain_max
and seg_start <= seg_end
):
return (
f"Residue range {seg_start}-{seg_end} for chain '{seg_chain_id}' in '{segment_text}' "
f"is invalid/out of bounds. Chain '{seg_chain_id}' actual range: {chain_min}-{chain_max}."
)
return None return None
def _validate_rfd_hotspots(self, hotspot_list: List[str], target_chain_details: Dict[str, Dict[str, int]]) -> Optional[str]: def _validate_rfd_hotspots(
self, hotspot_list: List[str], target_chain_details: Dict[str, Dict[str, int]]
) -> Optional[str]:
""" """
Validates hotspot residues (e.g., ["A50", "B25"]) against target PDB chain details. Validates hotspot residues (e.g., ["A50", "B25"]) against target PDB chain details.
Returns None if valid, or an error message string if invalid. Returns None if valid, or an error message string if invalid.
""" """
if not hotspot_list: return None if not hotspot_list:
return None
hotspot_pattern = re.compile(r"([A-Za-z0-9])(\d+)") hotspot_pattern = re.compile(r"([A-Za-z0-9])(\d+)")
for hotspot_str in hotspot_list: for hotspot_str in hotspot_list:
match = hotspot_pattern.fullmatch(hotspot_str.strip()) # Add strip match = hotspot_pattern.fullmatch(hotspot_str.strip()) # Add strip
if not match: if not match:
return f"Hotspot '{hotspot_str}' is not in expected format (e.g., 'A50')." return (
f"Hotspot '{hotspot_str}' is not in expected format (e.g., 'A50')."
)
hs_chain_id, hs_res_num_str = match.groups() hs_chain_id, hs_res_num_str = match.groups()
hs_res_num = int(hs_res_num_str) hs_res_num = int(hs_res_num_str)
if hs_chain_id not in target_chain_details: if hs_chain_id not in target_chain_details:
return f"Chain '{hs_chain_id}' for hotspot '{hotspot_str}' not in target. Valid: {list(target_chain_details.keys())}." return (
f"Chain '{hs_chain_id}' for hotspot '{hotspot_str}' not in target. "
f"Valid: {list(target_chain_details.keys())}."
)
chain_min = target_chain_details[hs_chain_id]["min_residue"] chain_min = target_chain_details[hs_chain_id]["min_residue"]
chain_max = target_chain_details[hs_chain_id]["max_residue"] chain_max = target_chain_details[hs_chain_id]["max_residue"]
if not (chain_min <= hs_res_num <= chain_max): if not (chain_min <= hs_res_num <= chain_max):
return (f"Residue {hs_res_num} for hotspot '{hotspot_str}' (chain '{hs_chain_id}') " return (
f"out of bounds. Chain '{hs_chain_id}' actual range: {chain_min}-{chain_max}.") f"Residue {hs_res_num} for hotspot '{hotspot_str}' (chain '{hs_chain_id}') "
f"out of bounds. Chain '{hs_chain_id}' actual range: {chain_min}-{chain_max}."
)
return None return None
async def _run_nim_alphafold2(self, args: Dict, workflow_state: Dict) -> Dict: async def _run_nim_alphafold2(self, args: Dict, workflow_state: Dict) -> Dict:
@ -96,13 +127,18 @@ class ToolExecutor:
state_updates = {} state_updates = {}
if self.debug_protein_design_calls: if self.debug_protein_design_calls:
logger.warning(f"DEBUG MODE: Bypassing AlphaFold2 API call for workflow {item_id}.") logger.warning(
f"DEBUG MODE: Bypassing AlphaFold2 API call for workflow {item_id}."
)
module_dir = Path(__file__).parent module_dir = Path(__file__).parent
fixed_pdb_path = module_dir / "debug_target.pdb" fixed_pdb_path = module_dir / "debug_target.pdb"
if not fixed_pdb_path.exists(): if not fixed_pdb_path.exists():
logger.error(f"Debug mode failed: {fixed_pdb_path} not found.") logger.error(f"Debug mode failed: {fixed_pdb_path} not found.")
tool_output = {"success": False, "error": f"Debug mode failed: Required file {fixed_pdb_path} not found."} tool_output = {
"success": False,
"error": f"Debug mode failed: Required file {fixed_pdb_path} not found.",
}
return {"tool_output": tool_output, "state_updates": state_updates} return {"tool_output": tool_output, "state_updates": state_updates}
with open(fixed_pdb_path, "r") as f: with open(fixed_pdb_path, "r") as f:
@ -115,7 +151,10 @@ class ToolExecutor:
state_updates["target_pdb_preview"] = pdb_preview state_updates["target_pdb_preview"] = pdb_preview
state_updates["target_structure_predicted"] = True state_updates["target_structure_predicted"] = True
debug_pdb_path = self.output_dir / f"target_{item_id}_s{current_internal_step}_af2_DEBUG.pdb" debug_pdb_path = (
self.output_dir
/ f"target_{item_id}_s{current_internal_step}_af2_DEBUG.pdb"
)
with open(debug_pdb_path, "w") as f: with open(debug_pdb_path, "w") as f:
f.write(pdb_content) f.write(pdb_content)
logger.info(f"DEBUG MODE: Copied fixed AlphaFold2 PDB to {debug_pdb_path}") logger.info(f"DEBUG MODE: Copied fixed AlphaFold2 PDB to {debug_pdb_path}")
@ -124,25 +163,31 @@ class ToolExecutor:
"success": True, "success": True,
"message": "DEBUG MODE: Used fixed PDB for AlphaFold2.", "message": "DEBUG MODE: Used fixed PDB for AlphaFold2.",
"target_pdb_preview": pdb_preview, "target_pdb_preview": pdb_preview,
"saved_pdb_path": str(debug_pdb_path) "saved_pdb_path": str(debug_pdb_path),
} }
return {"tool_output": tool_output, "state_updates": state_updates} return {"tool_output": tool_output, "state_updates": state_updates}
sequence_from_llm = args.get("sequence") sequence_from_llm = args.get("sequence")
if not sequence_from_llm: if not sequence_from_llm:
tool_output = {"success": False, "error": "Missing 'sequence' for AlphaFold2."} tool_output = {
"success": False,
"error": "Missing 'sequence' for AlphaFold2.",
}
return {"tool_output": tool_output, "state_updates": state_updates} return {"tool_output": tool_output, "state_updates": state_updates}
actual_sequence_to_use = target_sequence_from_state actual_sequence_to_use = target_sequence_from_state
if sequence_from_llm != target_sequence_from_state: if sequence_from_llm != target_sequence_from_state:
logger.warning( logger.warning(
f"LLM provided sequence '{sequence_from_llm[:20]}...' for 'predict_target_structure_alphafold2'. " f"LLM provided sequence '{sequence_from_llm[:20]}...' for 'predict_target_structure_alphafold2'. "
f"However, this tool will use the canonical target sequence from the workflow state: '{target_sequence_from_state[:20]}...'" f"However, this tool will use the canonical target sequence from the workflow state: "
f"'{target_sequence_from_state[:20]}...'"
) )
api_result = await call_alphafold2( api_result = await call_alphafold2(
sequence=actual_sequence_to_use, api_key=self.nim_api_key, sequence=actual_sequence_to_use,
timeout=self.api_timeout, polling_interval=self.polling_interval api_key=self.nim_api_key,
timeout=self.api_timeout,
polling_interval=self.polling_interval,
) )
if api_result and isinstance(api_result, list) and api_result[0]: if api_result and isinstance(api_result, list) and api_result[0]:
pdb_content = api_result[0] pdb_content = api_result[0]
@ -153,20 +198,34 @@ class ToolExecutor:
state_updates["target_pdb_preview"] = pdb_preview state_updates["target_pdb_preview"] = pdb_preview
state_updates["target_structure_predicted"] = True state_updates["target_structure_predicted"] = True
pdb_path = self.output_dir / f"target_{item_id}_s{current_internal_step}_af2.pdb" pdb_path = (
with open(pdb_path, "w") as f: f.write(pdb_content) self.output_dir / f"target_{item_id}_s{current_internal_step}_af2.pdb"
logger.info(f"Workflow {item_id}: AlphaFold2 PDB saved to {pdb_path}. Chain details: {chain_details}") )
with open(pdb_path, "w") as f:
f.write(pdb_content)
logger.info(
f"Workflow {item_id}: AlphaFold2 PDB saved to {pdb_path}. "
f"Chain details: {chain_details}"
)
tool_output = {"success": True, "message": "AlphaFold2 prediction complete.", "target_pdb_preview": pdb_preview, "saved_pdb_path": str(pdb_path)} tool_output = {
"success": True,
"message": "AlphaFold2 prediction complete.",
"target_pdb_preview": pdb_preview,
"saved_pdb_path": str(pdb_path),
}
else: else:
error_detail = api_result.get("error", "AlphaFold2 prediction failed.") if isinstance(api_result, dict) else "AlphaFold2 prediction failed." error_detail = (
api_result.get("error", "AlphaFold2 prediction failed.")
if isinstance(api_result, dict)
else "AlphaFold2 prediction failed."
)
logger.error(f"Workflow {item_id}: AlphaFold2 call failed: {error_detail}") logger.error(f"Workflow {item_id}: AlphaFold2 call failed: {error_detail}")
tool_output = {"success": False, "error": error_detail} tool_output = {"success": False, "error": error_detail}
state_updates["target_structure_predicted"] = False state_updates["target_structure_predicted"] = False
return {"tool_output": tool_output, "state_updates": state_updates} return {"tool_output": tool_output, "state_updates": state_updates}
async def _run_nim_rfdiffusion(self, args: Dict, workflow_state: Dict) -> Dict: async def _run_nim_rfdiffusion(self, args: Dict, workflow_state: Dict) -> Dict:
""" """
Runs RFDiffusion for binder backbone design. Returns structured output with Runs RFDiffusion for binder backbone design. Returns structured output with
@ -182,30 +241,55 @@ class ToolExecutor:
contigs_str_from_llm = args.get("contigs") contigs_str_from_llm = args.get("contigs")
if not target_pdb_content: if not target_pdb_content:
tool_output = {"success": False, "error": "Target PDB not found in state for RFDiffusion."} tool_output = {
"success": False,
"error": "Target PDB not found in state for RFDiffusion.",
}
return {"tool_output": tool_output, "state_updates": state_updates} return {"tool_output": tool_output, "state_updates": state_updates}
if not contigs_str_from_llm: if not contigs_str_from_llm:
tool_output = {"success": False, "error": "Missing 'contigs' for RFDiffusion."} tool_output = {
"success": False,
"error": "Missing 'contigs' for RFDiffusion.",
}
return {"tool_output": tool_output, "state_updates": state_updates} return {"tool_output": tool_output, "state_updates": state_updates}
validation_error = self._validate_rfd_contigs(contigs_str_from_llm, target_chain_details) validation_error = self._validate_rfd_contigs(
contigs_str_from_llm, target_chain_details
)
if validation_error: if validation_error:
logger.warning(f"RFDiffusion contigs validation failed for item {item_id}: {validation_error}. Contigs: '{contigs_str_from_llm}'") logger.warning(
tool_output = {"success": False, "error": f"Invalid contigs: {validation_error}"} f"RFDiffusion contigs validation failed for item {item_id}: {validation_error}. "
f"Contigs: '{contigs_str_from_llm}'"
)
tool_output = {
"success": False,
"error": f"Invalid contigs: {validation_error}",
}
return {"tool_output": tool_output, "state_updates": state_updates} return {"tool_output": tool_output, "state_updates": state_updates}
hotspot_residues = args.get("hotspot_residues") hotspot_residues = args.get("hotspot_residues")
if hotspot_residues: if hotspot_residues:
hotspot_validation_error = self._validate_rfd_hotspots(hotspot_residues, target_chain_details) hotspot_validation_error = self._validate_rfd_hotspots(
hotspot_residues, target_chain_details
)
if hotspot_validation_error: if hotspot_validation_error:
logger.warning(f"RFDiffusion hotspot validation failed for item {item_id}: {hotspot_validation_error}. Hotspots: {hotspot_residues}") logger.warning(
tool_output = {"success": False, "error": f"Invalid hotspots: {hotspot_validation_error}"} f"RFDiffusion hotspot validation failed for item {item_id}: {hotspot_validation_error}. "
f"Hotspots: {hotspot_residues}"
)
tool_output = {
"success": False,
"error": f"Invalid hotspots: {hotspot_validation_error}",
}
return {"tool_output": tool_output, "state_updates": state_updates} return {"tool_output": tool_output, "state_updates": state_updates}
api_result = await call_rfdiffusion( api_result = await call_rfdiffusion(
input_pdb=target_pdb_content, api_key=self.nim_api_key, input_pdb=target_pdb_content,
contigs=contigs_str_from_llm, hotspot_res=hotspot_residues, api_key=self.nim_api_key,
timeout=self.api_timeout, polling_interval=self.polling_interval contigs=contigs_str_from_llm,
hotspot_res=hotspot_residues,
timeout=self.api_timeout,
polling_interval=self.polling_interval,
) )
if api_result and "output_pdb" in api_result: if api_result and "output_pdb" in api_result:
@ -217,20 +301,34 @@ class ToolExecutor:
state_updates["binder_pdb_preview"] = binder_pdb_preview state_updates["binder_pdb_preview"] = binder_pdb_preview
state_updates["binder_backbone_designed"] = True state_updates["binder_backbone_designed"] = True
pdb_path = self.output_dir / f"binder_backbone_{item_id}_s{current_internal_step}_rfd.pdb" pdb_path = (
with open(pdb_path, "w") as f: f.write(binder_pdb) self.output_dir
/ f"binder_backbone_{item_id}_s{current_internal_step}_rfd.pdb"
)
with open(pdb_path, "w") as f:
f.write(binder_pdb)
logger.info(f"Workflow {item_id}: RFDiffusion PDB saved to {pdb_path}") logger.info(f"Workflow {item_id}: RFDiffusion PDB saved to {pdb_path}")
tool_output = {"success": True, "message": "RFDiffusion complete.", "binder_backbone_pdb_preview": binder_pdb_preview, "saved_pdb_path": str(pdb_path)} tool_output = {
"success": True,
"message": "RFDiffusion complete.",
"binder_backbone_pdb_preview": binder_pdb_preview,
"saved_pdb_path": str(pdb_path),
}
else: else:
error_detail = api_result.get("error", "RFDiffusion failed.") if isinstance(api_result, dict) else "RFDiffusion failed." error_detail = (
logger.error(f"Workflow {item_id}: RFDiffusion call failed: {error_detail}. API Result: {api_result}") api_result.get("error", "RFDiffusion failed.")
if isinstance(api_result, dict)
else "RFDiffusion failed."
)
logger.error(
f"Workflow {item_id}: RFDiffusion call failed: {error_detail}. API Result: {api_result}"
)
tool_output = {"success": False, "error": error_detail} tool_output = {"success": False, "error": error_detail}
state_updates["binder_backbone_designed"] = False state_updates["binder_backbone_designed"] = False
return {"tool_output": tool_output, "state_updates": state_updates} return {"tool_output": tool_output, "state_updates": state_updates}
async def _run_nim_proteinmpnn(self, args: Dict, workflow_state: Dict) -> Dict: async def _run_nim_proteinmpnn(self, args: Dict, workflow_state: Dict) -> Dict:
""" """
Runs ProteinMPNN for binder sequence design. Returns structured output with Runs ProteinMPNN for binder sequence design. Returns structured output with
@ -244,20 +342,33 @@ class ToolExecutor:
state_updates = {} state_updates = {}
if not binder_pdb: if not binder_pdb:
tool_output = {"success": False, "error": "Binder backbone PDB not found for ProteinMPNN."} tool_output = {
"success": False,
"error": "Binder backbone PDB not found for ProteinMPNN.",
}
return {"tool_output": tool_output, "state_updates": state_updates} return {"tool_output": tool_output, "state_updates": state_updates}
sampling_temp_list = args.get("sampling_temp", [0.1]) sampling_temp_list = args.get("sampling_temp", [0.1])
api_result = await call_proteinmpnn( api_result = await call_proteinmpnn(
input_pdb=binder_pdb, api_key=self.nim_api_key, input_pdb=binder_pdb,
api_key=self.nim_api_key,
sampling_temp=sampling_temp_list, sampling_temp=sampling_temp_list,
timeout=self.api_timeout, polling_interval=self.polling_interval timeout=self.api_timeout,
polling_interval=self.polling_interval,
) )
if not (api_result and "mfasta" in api_result): if not (api_result and "mfasta" in api_result):
error_detail = api_result.get("error", "ProteinMPNN call failed or no mfasta in result.") if isinstance(api_result, dict) else "PMPNN call failed" error_detail = (
logger.error(f"Workflow {item_id}: ProteinMPNN call failed: {error_detail}. API Result: {api_result}") api_result.get(
"error", "ProteinMPNN call failed or no mfasta in result."
)
if isinstance(api_result, dict)
else "PMPNN call failed"
)
logger.error(
f"Workflow {item_id}: ProteinMPNN call failed: {error_detail}. API Result: {api_result}"
)
tool_output = {"success": False, "error": error_detail} tool_output = {"success": False, "error": error_detail}
state_updates["binder_sequence_designed"] = False state_updates["binder_sequence_designed"] = False
return {"tool_output": tool_output, "state_updates": state_updates} return {"tool_output": tool_output, "state_updates": state_updates}
@ -268,12 +379,15 @@ class ToolExecutor:
current_sequence_parts: List[str] = [] current_sequence_parts: List[str] = []
for line_content in fasta_content.splitlines(): for line_content in fasta_content.splitlines():
line = line_content.strip() line = line_content.strip()
if not line: continue if not line:
continue
if line.startswith(">"): if line.startswith(">"):
if current_header and current_sequence_parts: if current_header and current_sequence_parts:
full_sequence_line = "".join(current_sequence_parts) full_sequence_line = "".join(current_sequence_parts)
score_match = re.search(r"global_score=([-\d.]+)", current_header) score_match = re.search(r"global_score=([-\d.]+)", current_header)
global_score = float(score_match.group(1)) if score_match else -float('inf') global_score = (
float(score_match.group(1)) if score_match else -float("inf")
)
entries.append((global_score, current_header, full_sequence_line)) entries.append((global_score, current_header, full_sequence_line))
current_header = line current_header = line
current_sequence_parts = [] current_sequence_parts = []
@ -282,7 +396,7 @@ class ToolExecutor:
if current_header and current_sequence_parts: if current_header and current_sequence_parts:
full_sequence_line = "".join(current_sequence_parts) full_sequence_line = "".join(current_sequence_parts)
score_match = re.search(r"global_score=([-\d.]+)", current_header) score_match = re.search(r"global_score=([-\d.]+)", current_header)
global_score = float(score_match.group(1)) if score_match else -float('inf') global_score = float(score_match.group(1)) if score_match else -float("inf")
entries.append((global_score, current_header, full_sequence_line)) entries.append((global_score, current_header, full_sequence_line))
if not entries: if not entries:
@ -292,36 +406,59 @@ class ToolExecutor:
entries.sort(key=lambda x: x[0], reverse=True) entries.sort(key=lambda x: x[0], reverse=True)
best_global_score, best_header, best_full_sequence_line = entries[0] best_global_score, best_header, best_full_sequence_line = entries[0]
logger.info(f"Workflow {item_id}: Best PMPNN sequence chosen (global_score={best_global_score:.4f}) from header: '{best_header}' -> Seq line: '{best_full_sequence_line}'") logger.info(
f"Workflow {item_id}: Best PMPNN sequence chosen (global_score={best_global_score:.4f}) "
f"from header: '{best_header}' -> Seq line: '{best_full_sequence_line}'"
)
parsed_binder_chains = [s.strip() for s in best_full_sequence_line.split('/') if s.strip()] parsed_binder_chains = [
s.strip() for s in best_full_sequence_line.split("/") if s.strip()
]
if not parsed_binder_chains or not all(s and s.isalpha() and s.isupper() for s in parsed_binder_chains): if not parsed_binder_chains or not all(
tool_output = {"success": False, "error": f"Invalid binder chains from PMPNN after parsing '{best_full_sequence_line}'. Parsed: {parsed_binder_chains}"} s and s.isalpha() and s.isupper() for s in parsed_binder_chains
):
tool_output = {
"success": False,
"error": (
f"Invalid binder chains from PMPNN after parsing '{best_full_sequence_line}'. "
f"Parsed: {parsed_binder_chains}"
),
}
state_updates["binder_sequence_designed"] = False state_updates["binder_sequence_designed"] = False
return {"tool_output": tool_output, "state_updates": state_updates} return {"tool_output": tool_output, "state_updates": state_updates}
state_updates["designed_binder_sequence"] = parsed_binder_chains state_updates["designed_binder_sequence"] = parsed_binder_chains
state_updates["binder_sequence_designed"] = True state_updates["binder_sequence_designed"] = True
fasta_path = self.output_dir / f"binder_sequence_{item_id}_s{current_internal_step}_pmpnn.fasta" fasta_path = (
with open(fasta_path, "w") as f: f.write(fasta_content) self.output_dir
logger.info(f"Workflow {item_id}: ProteinMPNN FASTA saved to {fasta_path}. Selected binder chains: {parsed_binder_chains}") / f"binder_sequence_{item_id}_s{current_internal_step}_pmpnn.fasta"
)
with open(fasta_path, "w") as f:
f.write(fasta_content)
logger.info(
f"Workflow {item_id}: ProteinMPNN FASTA saved to {fasta_path}. "
f"Selected binder chains: {parsed_binder_chains}"
)
preview = parsed_binder_chains[0][:60] + "..." if parsed_binder_chains else "N/A" preview = (
parsed_binder_chains[0][:60] + "..." if parsed_binder_chains else "N/A"
)
if len(parsed_binder_chains) > 1: if len(parsed_binder_chains) > 1:
preview += f" (+ {len(parsed_binder_chains)-1} more chain(s))" preview += f" (+ {len(parsed_binder_chains)-1} more chain(s))"
tool_output = { tool_output = {
"success": True, "success": True,
"message": f"ProteinMPNN complete. Selected best (global_score={best_global_score:.4f}).", "message": (
f"ProteinMPNN complete. Selected best (global_score={best_global_score:.4f})."
),
"designed_binder_sequence_list": parsed_binder_chains, "designed_binder_sequence_list": parsed_binder_chains,
"designed_binder_sequence_preview": preview, "designed_binder_sequence_preview": preview,
"saved_fasta_path": str(fasta_path) "saved_fasta_path": str(fasta_path),
} }
return {"tool_output": tool_output, "state_updates": state_updates} return {"tool_output": tool_output, "state_updates": state_updates}
async def _run_nim_af2_multimer(self, args: Dict, workflow_state: Dict) -> Dict: async def _run_nim_af2_multimer(self, args: Dict, workflow_state: Dict) -> Dict:
item_id = workflow_state["item_id"] item_id = workflow_state["item_id"]
current_internal_step = workflow_state["current_internal_step"] current_internal_step = workflow_state["current_internal_step"]
@ -331,16 +468,31 @@ class ToolExecutor:
tool_output = {} tool_output = {}
state_updates = {} state_updates = {}
if not target_seq or not designed_binder_chains_list or not isinstance(designed_binder_chains_list, list): if (
tool_output = {"success": False, "error": "Missing or invalid sequences for AF2-Multimer."} not target_seq
or not designed_binder_chains_list
or not isinstance(designed_binder_chains_list, list)
):
tool_output = {
"success": False,
"error": "Missing or invalid sequences for AF2-Multimer.",
}
return {"tool_output": tool_output, "state_updates": state_updates} return {"tool_output": tool_output, "state_updates": state_updates}
all_input_sequences_for_multimer = [target_seq] + designed_binder_chains_list all_input_sequences_for_multimer = [target_seq] + designed_binder_chains_list
for i, seq_to_validate in enumerate(all_input_sequences_for_multimer): for i, seq_to_validate in enumerate(all_input_sequences_for_multimer):
if not (seq_to_validate and isinstance(seq_to_validate, str) and seq_to_validate.isalpha() and seq_to_validate.isupper()): if not (
error_msg = (f"Sequence {i+1} (part of target/binder complex) is invalid: " seq_to_validate
f"'{str(seq_to_validate)[:30]}...'. Contains non-alpha/lowercase, is empty, or not a string.") and isinstance(seq_to_validate, str)
and seq_to_validate.isalpha()
and seq_to_validate.isupper()
):
error_msg = (
f"Sequence {i+1} (part of target/binder complex) is invalid: "
f"'{str(seq_to_validate)[:30]}...'. "
f"Contains non-alpha/lowercase, is empty, or not a string."
)
logger.error(f"Workflow {item_id}: {error_msg}") logger.error(f"Workflow {item_id}: {error_msg}")
tool_output = {"success": False, "error": error_msg} tool_output = {"success": False, "error": error_msg}
return {"tool_output": tool_output, "state_updates": state_updates} return {"tool_output": tool_output, "state_updates": state_updates}
@ -350,28 +502,41 @@ class ToolExecutor:
if self.debug_protein_design_calls: if self.debug_protein_design_calls:
self._debug_af2m_call_count += 1 self._debug_af2m_call_count += 1
mock_plddt = 87.5 if self._debug_af2m_call_count % 2 == 1 else 45.2 mock_plddt = 87.5 if self._debug_af2m_call_count % 2 == 1 else 45.2
success_message = f"DEBUG MODE: Returning {'high' if mock_plddt > 50 else 'low'}-quality mock results (call #{self._debug_af2m_call_count})" success_message = (
f"DEBUG MODE: Returning {'high' if mock_plddt > 50 else 'low'}-quality "
f"mock results (call #{self._debug_af2m_call_count})"
)
debug_pdb_filename = f"complex_{item_id}_s{current_internal_step}_af2m_DEBUG_pLDDT{mock_plddt:.2f}.pdb" debug_pdb_filename = f"complex_{item_id}_s{current_internal_step}_af2m_DEBUG_pLDDT{mock_plddt:.2f}.pdb"
debug_pdb_path = self.output_dir / debug_pdb_filename debug_pdb_path = self.output_dir / debug_pdb_filename
try: try:
with open(debug_pdb_path, "w") as f: with open(debug_pdb_path, "w") as f:
f.write(f"REMARK DEBUG PDB FILE for complex. Predicted pLDDT {mock_plddt}\n") f.write(
logger.info(f"DEBUG MODE: Saved mock AF2-Multimer PDB to {debug_pdb_path}") f"REMARK DEBUG PDB FILE for complex. Predicted pLDDT {mock_plddt}\n"
)
logger.info(
f"DEBUG MODE: Saved mock AF2-Multimer PDB to {debug_pdb_path}"
)
state_updates["complex_pdb_content_path"] = str(debug_pdb_path) state_updates["complex_pdb_content_path"] = str(debug_pdb_path)
except IOError as e: except IOError as e:
logger.error(f"DEBUG MODE: Failed to write mock PDB {debug_pdb_path}: {e}") logger.error(
f"DEBUG MODE: Failed to write mock PDB {debug_pdb_path}: {e}"
)
# If saving fails, don't set the path, but can still proceed with mock pLDDT # If saving fails, don't set the path, but can still proceed with mock pLDDT
state_updates["complex_pdb_content_path"] = None state_updates["complex_pdb_content_path"] = None
state_updates["af2_multimer_plddt"] = mock_plddt state_updates["af2_multimer_plddt"] = mock_plddt
state_updates["complex_evaluated"] = True state_updates["complex_evaluated"] = True
tool_output = { tool_output = {
"success": True, "message": f"{success_message}. Mock pLDDT: {mock_plddt:.2f}", "success": True,
"message": f"{success_message}. Mock pLDDT: {mock_plddt:.2f}",
"plddt": mock_plddt, "plddt": mock_plddt,
"complex_file_path": str(debug_pdb_path) if state_updates["complex_pdb_content_path"] else None "complex_file_path": (
str(debug_pdb_path)
if state_updates["complex_pdb_content_path"]
else None
),
} }
return {"tool_output": tool_output, "state_updates": state_updates} return {"tool_output": tool_output, "state_updates": state_updates}
@ -380,36 +545,57 @@ class ToolExecutor:
api_key=self.nim_api_key, api_key=self.nim_api_key,
relax_prediction=relax, relax_prediction=relax,
timeout=self.api_timeout, timeout=self.api_timeout,
polling_interval=self.polling_interval polling_interval=self.polling_interval,
) )
if api_result is None or (isinstance(api_result, dict) and api_result.get("success") is False): if api_result is None or (
isinstance(api_result, dict) and api_result.get("success") is False
):
error_detail = "AF2-Multimer call failed or returned None." error_detail = "AF2-Multimer call failed or returned None."
if isinstance(api_result, dict): if isinstance(api_result, dict):
error_detail = api_result.get("error", "AF2-Multimer call failed with unspecified error.") error_detail = api_result.get(
"error", "AF2-Multimer call failed with unspecified error."
)
detail_info = api_result.get("detail", "") detail_info = api_result.get("detail", "")
if detail_info: error_detail += f" Details: {detail_info}" if detail_info:
error_detail += f" Details: {detail_info}"
logger.error(f"Workflow {item_id}: AF2-Multimer call failed: {error_detail}. API Result: {api_result}") logger.error(
f"Workflow {item_id}: AF2-Multimer call failed: {error_detail}. "
f"API Result: {api_result}"
)
tool_output = {"success": False, "error": error_detail} tool_output = {"success": False, "error": error_detail}
state_updates["complex_evaluated"] = False state_updates["complex_evaluated"] = False
return {"tool_output": tool_output, "state_updates": state_updates} return {"tool_output": tool_output, "state_updates": state_updates}
all_structures_info = api_result.get("structures") all_structures_info = api_result.get("structures")
if not all_structures_info or not isinstance(all_structures_info, list): if not all_structures_info or not isinstance(all_structures_info, list):
message = api_result.get("message", "No structures returned from AF2-Multimer process.") message = api_result.get(
"message", "No structures returned from AF2-Multimer process."
)
logger.warning(f"Workflow {item_id}: {message}. API Result: {api_result}") logger.warning(f"Workflow {item_id}: {message}. API Result: {api_result}")
if not all_structures_info and isinstance(all_structures_info, list): if not all_structures_info and isinstance(all_structures_info, list):
tool_output = {"success": True, "message": "AF2-Multimer ran, but no PDB structures were produced by the API.", "plddt": 0.0, "complex_file_path": None} tool_output = {
state_updates["af2_multimer_plddt"] = 0.0 "success": True,
state_updates["complex_evaluated"] = True "message": (
state_updates["complex_pdb_content_path"] = None "AF2-Multimer ran, but no PDB structures were produced by the API."
),
"plddt": 0.0,
"complex_file_path": None,
}
state_updates["af2_multimer_plddt"] = 0.0
state_updates["complex_evaluated"] = True
state_updates["complex_pdb_content_path"] = None
else: else:
tool_output = {"success": False, "error": "AF2-Multimer returned unexpected data or no structures."} tool_output = {
"success": False,
"error": (
"AF2-Multimer returned unexpected data or no structures."
),
}
state_updates["complex_evaluated"] = False state_updates["complex_evaluated"] = False
return {"tool_output": tool_output, "state_updates": state_updates} return {"tool_output": tool_output, "state_updates": state_updates}
best_structure_info = None best_structure_info = None
highest_plddt = -1.0 highest_plddt = -1.0
@ -420,8 +606,14 @@ class ToolExecutor:
best_structure_info = struct_info best_structure_info = struct_info
if best_structure_info is None: if best_structure_info is None:
logger.error(f"Workflow {item_id}: No valid structure with pLDDT found in AF2-Multimer results, though structures were present.") logger.error(
tool_output = {"success": False, "error": "No valid structure with pLDDT in AF2-Multimer results."} f"Workflow {item_id}: No valid structure with pLDDT found in AF2-Multimer results, "
f"though structures were present."
)
tool_output = {
"success": False,
"error": ("No valid structure with pLDDT in AF2-Multimer results."),
}
state_updates["complex_evaluated"] = False state_updates["complex_evaluated"] = False
return {"tool_output": tool_output, "state_updates": state_updates} return {"tool_output": tool_output, "state_updates": state_updates}
@ -430,53 +622,81 @@ class ToolExecutor:
best_model_idx = best_structure_info.get("model_index", "NA") best_model_idx = best_structure_info.get("model_index", "NA")
if not best_pdb_content: if not best_pdb_content:
logger.error(f"Workflow {item_id}: Best AF2-Multimer structure (Model {best_model_idx}, pLDDT {best_plddt:.2f}) found, but PDB content is missing.") logger.error(
tool_output = {"success": False, "error": f"Best model (pLDDT {best_plddt:.2f}) has no PDB content."} f"Workflow {item_id}: Best AF2-Multimer structure (Model {best_model_idx}, "
f"pLDDT {best_plddt:.2f}) found, but PDB content is missing."
)
tool_output = {
"success": False,
"error": f"Best model (pLDDT {best_plddt:.2f}) has no PDB content.",
}
state_updates["complex_evaluated"] = False state_updates["complex_evaluated"] = False
state_updates["af2_multimer_plddt"] = best_plddt state_updates["af2_multimer_plddt"] = best_plddt
return {"tool_output": tool_output, "state_updates": state_updates} return {"tool_output": tool_output, "state_updates": state_updates}
complex_pdb_filename = f"complex_{item_id}_s{current_internal_step}_af2m_model{best_model_idx}_pLDDT{best_plddt:.2f}.pdb" complex_pdb_filename = (
f"complex_{item_id}_s{current_internal_step}_af2m_model{best_model_idx}_"
f"pLDDT{best_plddt:.2f}.pdb"
)
complex_pdb_path = self.output_dir / complex_pdb_filename complex_pdb_path = self.output_dir / complex_pdb_filename
try: try:
with open(complex_pdb_path, "w", encoding='utf-8') as f: with open(complex_pdb_path, "w", encoding="utf-8") as f:
f.write(best_pdb_content) f.write(best_pdb_content)
logger.info(f"Workflow {item_id}: AlphaFold2-Multimer complete. Saved best model (Index {best_model_idx}) with pLDDT: {best_plddt:.2f} from {len(all_structures_info)} models to {complex_pdb_path}") logger.info(
f"Workflow {item_id}: AlphaFold2-Multimer complete. Saved best model (Index {best_model_idx}) "
f"with pLDDT: {best_plddt:.2f} from {len(all_structures_info)} models to {complex_pdb_path}"
)
state_updates["complex_pdb_content_path"] = str(complex_pdb_path) state_updates["complex_pdb_content_path"] = str(complex_pdb_path)
state_updates["af2_multimer_plddt"] = best_plddt state_updates["af2_multimer_plddt"] = best_plddt
state_updates["complex_evaluated"] = True state_updates["complex_evaluated"] = True
complex_quality_message = f"AlphaFold2-Multimer evaluation complete. Selected best model (Index {best_model_idx}) with pLDDT: {best_plddt:.2f}" complex_quality_message = (
f"AlphaFold2-Multimer evaluation complete. Selected best model (Index {best_model_idx}) "
f"with pLDDT: {best_plddt:.2f}"
)
tool_output = { tool_output = {
"success": True, "success": True,
"message": complex_quality_message, "message": complex_quality_message,
"plddt": best_plddt, "plddt": best_plddt,
"complex_file_path": str(complex_pdb_path), "complex_file_path": str(complex_pdb_path),
"selected_model_index": best_model_idx "selected_model_index": best_model_idx,
} }
except IOError as e: except IOError as e:
logger.error(f"Workflow {item_id}: Failed to save best AF2-Multimer PDB (Model {best_model_idx}, pLDDT {best_plddt:.2f}) to {complex_pdb_path}: {e}") logger.error(
tool_output = {"success": False, "error": f"Failed to save best complex PDB: {e}"} f"Workflow {item_id}: Failed to save best AF2-Multimer PDB (Model {best_model_idx}, "
f"pLDDT {best_plddt:.2f}) to {complex_pdb_path}: {e}"
)
tool_output = {
"success": False,
"error": f"Failed to save best complex PDB: {e}",
}
state_updates["af2_multimer_plddt"] = best_plddt state_updates["af2_multimer_plddt"] = best_plddt
state_updates["complex_pdb_content_path"] = None state_updates["complex_pdb_content_path"] = None
state_updates["complex_evaluated"] = True state_updates["complex_evaluated"] = True
return {"tool_output": tool_output, "state_updates": state_updates} return {"tool_output": tool_output, "state_updates": state_updates}
async def dispatch_tool_call(
async def dispatch_tool_call(self, tool_name: str, args: Dict, workflow_state: Dict) -> Dict: self, tool_name: str, args: Dict, workflow_state: Dict
) -> Dict:
"""Main dispatch method for executing tools.""" """Main dispatch method for executing tools."""
item_id = workflow_state["item_id"] item_id = workflow_state["item_id"]
internal_step = workflow_state["current_internal_step"] internal_step = workflow_state["current_internal_step"]
logger.info(f"ToolExecutor: Dispatching tool '{tool_name}' for workflow {item_id}, Step {internal_step} with args: {args}") logger.info(
f"ToolExecutor: Dispatching tool '{tool_name}' for workflow {item_id}, "
f"Step {internal_step} with args: {args}"
)
if not self.nim_api_key: if not self.nim_api_key:
return { return {
"tool_output": {"success": False, "error": "NIM API key not configured in ToolExecutor."}, "tool_output": {
"state_updates": {} "success": False,
"error": "NIM API key not configured in ToolExecutor.",
},
"state_updates": {},
} }
if tool_name == "predict_target_structure_alphafold2": if tool_name == "predict_target_structure_alphafold2":
@ -488,8 +708,13 @@ class ToolExecutor:
elif tool_name == "evaluate_binder_complex_alphafold2_multimer": elif tool_name == "evaluate_binder_complex_alphafold2_multimer":
return await self._run_nim_af2_multimer(args, workflow_state) return await self._run_nim_af2_multimer(args, workflow_state)
else: else:
logger.error(f"ToolExecutor: Unknown tool name '{tool_name}' for workflow {item_id}") logger.error(
f"ToolExecutor: Unknown tool name '{tool_name}' for workflow {item_id}"
)
return { return {
"tool_output": {"success": False, "error": f"Unknown tool name: {tool_name}"}, "tool_output": {
"state_updates": {} "success": False,
"error": f"Unknown tool name: {tool_name}",
},
"state_updates": {},
} }

View file

@ -1,13 +1,13 @@
import os
import logging import logging
import yaml import os
from pathlib import Path
from typing import Optional from typing import Optional
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def load_api_key() -> Optional[str]: def load_api_key() -> Optional[str]:
""" """
Load the NVIDIA NIM API key from environment variables. Load the NVIDIA NIM API key from environment variables.
@ -17,8 +17,10 @@ def load_api_key() -> Optional[str]:
""" """
api_key = os.environ.get("NVIDIA_NIM_API_KEY") api_key = os.environ.get("NVIDIA_NIM_API_KEY")
if not api_key: if not api_key:
logger.error("NVIDIA_NIM_API_KEY not found in environment variables. " logger.error(
"Please set it in your .env file.") "NVIDIA_NIM_API_KEY not found in environment variables. "
"Please set it in your .env file."
)
return None return None
return api_key return api_key

View file

@ -1,9 +1,12 @@
import logging import logging
from typing import Dict, Tuple, List, Set, Union from typing import Dict, Set, Tuple, Union
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_pdb_chain_details(pdb_content: str, preview_lines: int = 10) -> Tuple[Dict[str, Dict[str, int]], str]:
def get_pdb_chain_details(
pdb_content: str, preview_lines: int = 10
) -> Tuple[Dict[str, Dict[str, int]], str]:
""" """
Parses PDB content to extract detailed information for each chain. Parses PDB content to extract detailed information for each chain.
@ -27,7 +30,8 @@ def get_pdb_chain_details(pdb_content: str, preview_lines: int = 10) -> Tuple[Di
if line.startswith("ATOM"): if line.startswith("ATOM"):
atom_lines.append(line) atom_lines.append(line)
chain_id = line[21:22].strip() chain_id = line[21:22].strip()
if not chain_id: chain_id = " " if not chain_id:
chain_id = " "
atom_name = line[12:16].strip() atom_name = line[12:16].strip()
try: try:
residue_num = int(line[22:26].strip()) residue_num = int(line[22:26].strip())
@ -39,7 +43,11 @@ def get_pdb_chain_details(pdb_content: str, preview_lines: int = 10) -> Tuple[Di
except ValueError: except ValueError:
logger.warning(f"Could not parse residue number from PDB line: {line}") logger.warning(f"Could not parse residue number from PDB line: {line}")
continue continue
elif line.startswith("HEADER") or line.startswith("TITLE") or line.startswith("COMPND"): elif (
line.startswith("HEADER")
or line.startswith("TITLE")
or line.startswith("COMPND")
):
header_lines.append(line) header_lines.append(line)
chain_details: Dict[str, Dict[str, int]] = {} chain_details: Dict[str, Dict[str, int]] = {}
@ -51,14 +59,16 @@ def get_pdb_chain_details(pdb_content: str, preview_lines: int = 10) -> Tuple[Di
chain_details[chain_id] = { chain_details[chain_id] = {
"min_residue": min_res, "min_residue": min_res,
"max_residue": max_res, "max_residue": max_res,
"length": length "length": length,
} }
else: else:
logger.warning(f"Chain {chain_id} had no parseable ATOM residue numbers.") logger.warning(f"Chain {chain_id} had no parseable ATOM residue numbers.")
preview_str_parts = header_lines[:min(len(header_lines), preview_lines // 2)] preview_str_parts = header_lines[: min(len(header_lines), preview_lines // 2)]
remaining_preview_lines = preview_lines - len(preview_str_parts) remaining_preview_lines = preview_lines - len(preview_str_parts)
preview_str_parts.extend(atom_lines[:min(len(atom_lines), remaining_preview_lines)]) preview_str_parts.extend(
atom_lines[: min(len(atom_lines), remaining_preview_lines)]
)
pdb_preview = "\n".join(preview_str_parts) pdb_preview = "\n".join(preview_str_parts)
if len(pdb_content.splitlines()) > preview_lines: if len(pdb_content.splitlines()) > preview_lines:
pdb_preview += "\n..." pdb_preview += "\n..."

View file

@ -1 +0,0 @@
"""Protein design model API modules."""

View file

@ -1,67 +0,0 @@
PREDICT_TARGET_STRUCTURE_TOOL = {
"type": "function",
"function": {
"name": "predict_target_structure_alphafold2",
"description": "Predicts the 3D structure of the target protein sequence using AlphaFold2.",
"parameters": {
"type": "object",
"properties": {
"sequence": {"type": "string", "description": "Amino acid sequence of the target protein."},
},
"required": ["sequence"]
}
}
}
DESIGN_BINDER_BACKBONE_TOOL = {
"type": "function",
"function": {
"name": "design_binder_backbone_rfdiffusion",
"description": "Generates a novel protein binder backbone using RFDiffusion, conditioned on the target protein structure.",
"parameters": {
"type": "object",
"properties": {
"contigs": {"type": "string", "description": "RFDiffusion contigs (e.g., 'A1-100/0 50-70')."},
"hotspot_residues": {"type": "array", "items": {"type": "string"}, "description": "Optional hotspot residues (e.g., ['A50', 'A52'])."},
},
"required": ["contigs"]
}
}
}
DESIGN_BINDER_SEQUENCE_TOOL = {
"type": "function",
"function": {
"name": "design_binder_sequence_proteinmpnn",
"description": "Designs an amino acid sequence for the generated binder backbone.",
"parameters": {
"type": "object",
"properties": {
"sampling_temp": {"type": "array", "items": {"type": "number"}, "description": "Sampling temperatures (e.g., [0.1, 0.2]). Default [0.1]."}
},
"required": []
}
}
}
EVALUATE_COMPLEX_TOOL = {
"type": "function",
"function": {
"name": "evaluate_binder_complex_alphafold2_multimer",
"description": "Predicts the complex structure of target and designed binder, providing quality scores.",
"parameters": {
"type": "object",
"properties": {
"relax_prediction": {"type": "boolean", "description": "Whether to relax the prediction. Default True."}
},
"required": []
}
}
}
ALL_TOOLS_LIST = [
PREDICT_TARGET_STRUCTURE_TOOL,
DESIGN_BINDER_BACKBONE_TOOL,
DESIGN_BINDER_SEQUENCE_TOOL,
EVALUATE_COMPLEX_TOOL
]