This commit is contained in:
Shannon Sands 2025-05-27 12:15:15 +10:00
parent 13a70e09ab
commit 54967ecae9
19 changed files with 1337 additions and 531 deletions

View file

@ -2040,6 +2040,65 @@ python test_stl_env.py
---
### 23. Protein Design Environment (`protein_design/`)
**Contributors**: hallerite, promachina
**PR**: [#70](https://github.com/NousResearch/atropos/pull/70)
**Integration Status**: ✅ Integrated
**Description**: A comprehensive reinforcement learning environment for de novo protein design through a staged simulation loop. This environment enables AI systems to learn the complete protein design workflow from target structure prediction to binder evaluation, using state-of-the-art protein modeling tools.
**Core Features**:
**Multi-Stage Protein Design Pipeline**:
- **AlphaFold2 Structure Prediction**: Predicts 3D structure of target proteins from amino acid sequences
- **RFDiffusion Backbone Generation**: Generates novel protein binder backbones conditioned on target structures
- **ProteinMPNN Sequence Design**: Designs optimal amino acid sequences for generated backbones
- **AlphaFold2-Multimer Evaluation**: Evaluates binding complex quality with pLDDT scoring
**Advanced Workflow Management**:
- **State-Based Progression**: Tracks workflow state through 4 distinct internal steps
- **Retry Logic**: Configurable retry mechanisms for failed tool executions
- **Validation Systems**: Comprehensive input validation for contigs, hotspots, and sequences
- **Error Handling**: Robust error recovery and detailed logging
**NVIDIA NIM Integration**:
- **API-Based Execution**: Leverages NVIDIA NIM APIs for protein modeling tools
- **Async Processing**: Non-blocking API calls with configurable timeouts and polling
- **Debug Mode**: Mock data generation for development and testing
- **Result Caching**: Saves intermediate PDB files and FASTA sequences
**Reward System**:
- **Format Rewards**: 0.2 points for correct tool usage in steps 0-2
- **Quality Rewards**: pLDDT-based scoring (0.0-1.0) for final complex evaluation
- **Progressive Scoring**: Cumulative rewards across workflow stages
**Data Management**:
- **Hugging Face Integration**: Loads protein binding datasets (ronig/protein_binding_sequences)
- **File Organization**: Structured output directory with timestamped results
- **Comprehensive Logging**: Detailed workflow tracking and performance metrics
**Research Applications**:
- **Drug Discovery**: Design novel protein binders for therapeutic targets
- **Protein Engineering**: Optimize protein-protein interactions
- **Structural Biology**: Explore protein design space systematically
- **AI Training**: Develop protein design capabilities in language models
**Technical Requirements**:
- NVIDIA NIM API access for protein modeling tools
- Python environment with protein analysis libraries
- Sufficient storage for PDB files and intermediate results
**Environment Configuration**:
- Configurable retry limits and timeout settings
- Debug mode for development without API calls
- Flexible dataset selection and column mapping
- WandB integration for experiment tracking
**Requirements**: pydantic, datasets, python-dotenv, pyyaml, wandb, atroposlib, nvidia-nim-api-client
---
## Support
For questions or issues with community environments:

View file

@ -1,3 +1,3 @@
# We use NVIDIA NIM to access hosted models on the API
NVIDIA_NIM_API_KEY: "YOUR API KEY"
NVIDIA_NIM_API_KEY: "YOUR API KEY"

View file

@ -38,7 +38,7 @@ Each episode consists of an LLM navigating a 4-step design pipeline, using state
### Step 4: Evaluate Binding (`AlphaFold-Multimer`)
- **Input:** Target + binder sequences
- **Output:** Complex structure prediction
- **Reward:**
- **Reward:**
- Format OK
- No steric clashes
- **Bonus:** Contact interface, binding affinity metrics (Not yet implemented)
@ -48,9 +48,9 @@ Each episode consists of an LLM navigating a 4-step design pipeline, using state
## 🏆 Reward Function
The reward is cumulative:
- **+0.2**: Successfully generate output in correct format at each step
- **+0.2**: Successfully generate output in correct format at each step
- **+0.0 to +1.0:** Structural reward based on complex validity smoothly interpolated on AlphaFold2 multimer confidence
- **+1**: High predicted binding affinity (Not yet implemented)
- **+1**: High predicted binding affinity (Not yet implemented)
Sparse, but real. LLMs must *plan* tool use, not just spam actions.

View file

@ -0,0 +1 @@
"""Protein design model API modules."""

View file

@ -1,16 +1,15 @@
import os
import logging
import aiohttp
import json
import asyncio
from typing import Dict, List, Any, Optional
from pathlib import Path
import logging
from typing import Any, Dict, List, Optional
import aiohttp
logger = logging.getLogger(__name__)
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
async def call_alphafold2(
sequence: str,
api_key: str,
@ -24,7 +23,7 @@ async def call_alphafold2(
status_url: str = DEFAULT_STATUS_URL,
polling_interval: int = 10,
timeout: int = 600,
max_retries: int = 3
max_retries: int = 3,
) -> Optional[Dict[str, Any]]:
"""
Call the NVIDIA NIM AlphaFold2 API.
@ -59,16 +58,13 @@ async def call_alphafold2(
"iterations": iterations,
"databases": databases,
"relax_prediction": relax_prediction,
"skip_template_search": skip_template_search
"skip_template_search": skip_template_search,
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(
url,
json=data,
headers=headers,
timeout=timeout
url, json=data, headers=headers, timeout=timeout
) as response:
if response.status == 200:
return await response.json()
@ -81,7 +77,7 @@ async def call_alphafold2(
headers=headers,
status_url=status_url,
polling_interval=polling_interval,
timeout=timeout
timeout=timeout,
)
else:
logger.error("No request ID in response headers")
@ -93,16 +89,18 @@ async def call_alphafold2(
return None
except Exception as e:
import traceback
logger.error(f"Error calling AlphaFold2 API: {e}")
logger.error(traceback.format_exc())
return None
async def _poll_job_status(
req_id: str,
headers: Dict[str, str],
status_url: str,
polling_interval: int = 10,
timeout: int = 60
timeout: int = 60,
) -> Optional[Dict[str, Any]]:
"""
Poll the status endpoint until the job completes.
@ -121,18 +119,20 @@ async def _poll_job_status(
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{status_url}/{req_id}",
headers=headers,
timeout=timeout
f"{status_url}/{req_id}", headers=headers, timeout=timeout
) as response:
if response.status == 200:
logger.info(f"AlphaFold2 job {req_id} completed")
return await response.json()
elif response.status == 202:
logger.debug(f"AlphaFold2 job {req_id} still running, polling...")
logger.debug(
f"AlphaFold2 job {req_id} still running, polling..."
)
await asyncio.sleep(polling_interval)
else:
logger.error(f"Error checking AlphaFold2 job status: {response.status}")
logger.error(
f"Error checking AlphaFold2 job status: {response.status}"
)
text = await response.text()
logger.error(f"Response: {text}")
return None

View file

@ -1,16 +1,16 @@
import os
import logging
import aiohttp
import json
import asyncio
from typing import Dict, List, Any, Optional, Tuple
from pathlib import Path
import json
import logging
from typing import Any, Dict, List, Optional, Tuple
import aiohttp
logger = logging.getLogger(__name__)
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2-multimer"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
"""
Splits a string containing concatenated PDB file contents.
@ -35,7 +35,9 @@ def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
return [pdb for pdb in pdbs if pdb]
def calculate_plddt_from_pdb_string(pdb_string: str) -> Tuple[float, List[float], Dict[str, List[float]]]:
def calculate_plddt_from_pdb_string(
pdb_string: str,
) -> Tuple[float, List[float], Dict[str, List[float]]]:
total_plddt = 0.0
ca_atom_count = 0
plddt_scores_per_ca: List[float] = []
@ -67,10 +69,11 @@ def calculate_plddt_from_pdb_string(pdb_string: str) -> Tuple[float, List[float]
average_plddt = total_plddt / ca_atom_count
return average_plddt, plddt_scores_per_ca, plddt_scores_per_chain
async def _process_pdb_and_scores_from_api(
pdb_contents: List[str],
job_id: str,
api_response_json: Optional[Dict[str, Any]] = None
api_response_json: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
"""
Processes a list of PDB strings received from the API.
@ -78,14 +81,19 @@ async def _process_pdb_and_scores_from_api(
- Does NOT save files to disk.
- Returns a dictionary containing a list of structures, each with its PDB content and scores.
"""
results: Dict[str, Any] = {
"structures": []
}
results: Dict[str, Any] = {"structures": []}
if not pdb_contents or not isinstance(pdb_contents, list) or not all(isinstance(s, str) for s in pdb_contents):
if (
not pdb_contents
or not isinstance(pdb_contents, list)
or not all(isinstance(s, str) for s in pdb_contents)
):
logger.warning(f"No valid PDB content strings provided for job {job_id}.")
return {"success": False, "error": "No valid PDB content strings from API.", "structures": []}
return {
"success": False,
"error": "No valid PDB content strings from API.",
"structures": [],
}
logger.info(f"Processing {len(pdb_contents)} PDB structure(s) for job {job_id}.")
@ -94,12 +102,11 @@ async def _process_pdb_and_scores_from_api(
logger.debug(f"Skipping empty PDB string at index {i} for job {job_id}.")
continue
structure_data: Dict[str, Any] = {
"model_index": i,
"pdb_content": pdb_str
}
structure_data: Dict[str, Any] = {"model_index": i, "pdb_content": pdb_str}
avg_plddt, plddts_per_ca_residue, plddts_by_chain = calculate_plddt_from_pdb_string(pdb_str)
avg_plddt, plddts_per_ca_residue, plddts_by_chain = (
calculate_plddt_from_pdb_string(pdb_str)
)
structure_data["average_plddt"] = avg_plddt
structure_data["plddt_scores_per_ca_residue"] = plddts_per_ca_residue
@ -116,13 +123,21 @@ async def _process_pdb_and_scores_from_api(
results["structures"].append(structure_data)
if results["structures"]:
logger.info(f"Successfully processed and calculated pLDDTs for {len(results['structures'])} structures for job {job_id}.")
logger.info(
f"Successfully processed and calculated pLDDTs for "
f"{len(results['structures'])} structures for job {job_id}."
)
else:
logger.warning(f"No structures were processed for job {job_id}.")
return {"success": True, "message": "No PDB structures found in API response to process.", "structures": []}
return {
"success": True,
"message": "No PDB structures found in API response to process.",
"structures": [],
}
return results
async def call_alphafold2_multimer(
sequences: List[str],
api_key: str,
@ -135,7 +150,7 @@ async def call_alphafold2_multimer(
url: str = DEFAULT_URL,
status_url: str = DEFAULT_STATUS_URL,
polling_interval: int = 30,
timeout: int = 3600
timeout: int = 3600,
) -> Optional[Dict[str, Any]]:
"""
Call the NVIDIA NIM AlphaFold2-Multimer API.
@ -155,7 +170,7 @@ async def call_alphafold2_multimer(
"e_value": e_value,
"iterations": iterations,
"databases": databases,
"relax_prediction": relax_prediction
"relax_prediction": relax_prediction,
}
if selected_models is not None:
data["selected_models"] = selected_models
@ -165,10 +180,7 @@ async def call_alphafold2_multimer(
initial_post_timeout = min(timeout, 600)
async with aiohttp.ClientSession() as session:
async with session.post(
url,
json=data,
headers=headers,
timeout=initial_post_timeout
url, json=data, headers=headers, timeout=initial_post_timeout
) as response:
if response.status == 200:
logger.info("AlphaFold2-Multimer job completed synchronously.")
@ -177,130 +189,239 @@ async def call_alphafold2_multimer(
if "application/json" in content_type:
api_response_json_payload = await response.json()
if not isinstance(api_response_json_payload, list):
if isinstance(api_response_json_payload, dict) and "error" in api_response_json_payload:
logger.error(f"Sync API call returned error: {api_response_json_payload['error']}")
return {"success": False, "error": api_response_json_payload['error'], "detail": api_response_json_payload.get("detail","")}
return {"success": False, "error": "Sync JSON response not a list of PDBs as expected."}
if (
isinstance(api_response_json_payload, dict)
and "error" in api_response_json_payload
):
logger.error(
f"Sync API call returned error: "
f"{api_response_json_payload['error']}"
)
return {
"success": False,
"error": api_response_json_payload["error"],
"detail": api_response_json_payload.get(
"detail", ""
),
}
return {
"success": False,
"error": "Sync JSON response not a list of PDBs as expected.",
}
req_id_sync = response.headers.get("nvcf-reqid", "sync_job")
return await _process_pdb_and_scores_from_api(
pdb_contents=api_response_json_payload,
job_id=req_id_sync,
api_response_json=None
api_response_json=None,
)
else:
err_text = await response.text()
logger.error(f"Sync response unexpected content type: {content_type}. Response: {err_text[:500]}")
return {"success": False, "error": f"Sync response unexpected content type: {content_type}", "detail": err_text}
logger.error(
f"Sync response unexpected content type: {content_type}. "
f"Response: {err_text[:500]}"
)
return {
"success": False,
"error": f"Sync response unexpected content type: {content_type}",
"detail": err_text,
}
elif response.status == 202:
req_id = response.headers.get("nvcf-reqid")
if req_id:
logger.info(f"AlphaFold2-Multimer job submitted, request ID: {req_id}")
logger.info(
f"AlphaFold2-Multimer job submitted, request ID: {req_id}"
)
return await _poll_job_status(
req_id=req_id,
headers=headers,
status_url=status_url,
polling_interval=polling_interval,
overall_timeout=timeout
overall_timeout=timeout,
)
else:
logger.error("No request ID in 202 response headers")
return {"success": False, "error": "No request ID in 202 response headers"}
return {
"success": False,
"error": "No request ID in 202 response headers",
}
else:
logger.error(f"Error calling AlphaFold2-Multimer API (POST): {response.status}")
logger.error(
f"Error calling AlphaFold2-Multimer API (POST): {response.status}"
)
text = await response.text()
logger.error(f"Response: {text}")
return {"success": False, "error": f"Error calling API: {response.status}", "detail": text}
return {
"success": False,
"error": f"Error calling API: {response.status}",
"detail": text,
}
except asyncio.TimeoutError:
logger.error(f"Timeout during AlphaFold2-Multimer API (initial POST).")
logger.error("Timeout during AlphaFold2-Multimer API (initial POST).")
return {"success": False, "error": "Timeout during initial API request"}
except Exception as e:
logger.error(f"Exception during AlphaFold2-Multimer API call (initial POST): {e}", exc_info=True)
logger.error(
f"Exception during AlphaFold2-Multimer API call (initial POST): {e}",
exc_info=True,
)
return {"success": False, "error": f"Exception during API call: {str(e)}"}
async def _poll_job_status(
req_id: str,
headers: Dict[str, str],
status_url: str,
polling_interval: int = 30,
overall_timeout: int = 3600
overall_timeout: int = 3600,
) -> Optional[Dict[str, Any]]:
start_time = asyncio.get_event_loop().time()
per_status_request_timeout = 600
logger.info(f"Polling job {req_id}. Individual status check timeout: {per_status_request_timeout}s, Polling interval: {polling_interval}s, Overall timeout: {overall_timeout}s")
logger.info(
f"Polling job {req_id}. Individual status check timeout: "
f"{per_status_request_timeout}s, Polling interval: {polling_interval}s, "
f"Overall timeout: {overall_timeout}s"
)
while True:
current_loop_time = asyncio.get_event_loop().time()
elapsed_time = current_loop_time - start_time
if elapsed_time >= overall_timeout:
logger.error(f"Overall polling timeout of {overall_timeout}s exceeded for job {req_id}.")
logger.error(
f"Overall polling timeout of {overall_timeout}s exceeded for "
f"job {req_id}."
)
return {"success": False, "error": "Overall polling timeout exceeded."}
remaining_time_for_overall_timeout = overall_timeout - elapsed_time
current_status_check_timeout = min(per_status_request_timeout, remaining_time_for_overall_timeout)
current_status_check_timeout = min(
per_status_request_timeout, remaining_time_for_overall_timeout
)
if current_status_check_timeout <= 0:
logger.error(f"Not enough time left for another status check for job {req_id} within overall_timeout.")
return {"success": False, "error": "Not enough time for status check within overall timeout."}
logger.error(
f"Not enough time left for another status check for job {req_id} "
f"within overall_timeout."
)
return {
"success": False,
"error": "Not enough time for status check within overall timeout.",
}
try:
async with aiohttp.ClientSession() as session:
logger.debug(f"Checking status for {req_id} with timeout {current_status_check_timeout}s.")
logger.debug(
f"Checking status for {req_id} with timeout "
f"{current_status_check_timeout}s."
)
async with session.get(
f"{status_url}/{req_id}",
headers=headers,
timeout=current_status_check_timeout
timeout=current_status_check_timeout,
) as response:
if response.status == 200:
logger.info(f"AlphaFold2-Multimer job {req_id} completed (status 200).")
if response.content_type == 'application/json':
logger.info(
f"AlphaFold2-Multimer job {req_id} completed (status 200)."
)
if response.content_type == "application/json":
try:
api_response_json_payload = await response.json()
if not isinstance(api_response_json_payload, list):
if isinstance(api_response_json_payload, dict) and "error" in api_response_json_payload:
logger.error(f"Job {req_id}: API returned error: {api_response_json_payload['error']}")
return {"success": False, "error": api_response_json_payload['error'], "detail": api_response_json_payload.get("detail","")}
logger.error(f"Job {req_id}: Expected API response to be a list of PDB strings, got {type(api_response_json_payload)}.")
return {"success": False, "error": "API response was not a list of PDB strings."}
if (
isinstance(api_response_json_payload, dict)
and "error" in api_response_json_payload
):
logger.error(
f"Job {req_id}: API returned error: "
f"{api_response_json_payload['error']}"
)
return {
"success": False,
"error": api_response_json_payload["error"],
"detail": api_response_json_payload.get(
"detail", ""
),
}
logger.error(
f"Job {req_id}: Expected API response to be a list of PDB strings, "
f"got {type(api_response_json_payload)}."
)
return {
"success": False,
"error": "API response was not a list of PDB strings.",
}
return await _process_pdb_and_scores_from_api(
pdb_contents=api_response_json_payload,
job_id=req_id,
api_response_json=None
api_response_json=None,
)
except json.JSONDecodeError:
logger.error(f"Job {req_id}: Failed to decode JSON response from API.", exc_info=True)
logger.error(
f"Job {req_id}: Failed to decode JSON response from API.",
exc_info=True,
)
raw_text = await response.text()
return {"success": False, "error": "Failed to decode JSON response.", "detail": raw_text[:500]}
return {
"success": False,
"error": "Failed to decode JSON response.",
"detail": raw_text[:500],
}
else:
raw_text = await response.text()
logger.error(f"Job {req_id}: Unexpected content type {response.content_type}. Expected application/json. Response: {raw_text[:500]}")
return {"success": False, "error": f"Unexpected content type: {response.content_type}", "detail": raw_text}
logger.error(
f"Job {req_id}: Unexpected content type {response.content_type}. "
f"Expected application/json. Response: {raw_text[:500]}"
)
return {
"success": False,
"error": f"Unexpected content type: {response.content_type}",
"detail": raw_text,
}
elif response.status == 202:
try:
job_status_json = await response.json()
percent_complete = job_status_json.get('percentComplete', 'N/A')
status_message = job_status_json.get('status', 'running')
percent_complete = job_status_json.get(
"percentComplete", "N/A"
)
status_message = job_status_json.get("status", "running")
logger.debug(
f"Job {req_id} status: {status_message} ({percent_complete}% complete). Polling again in {polling_interval}s."
f"Job {req_id} status: {status_message} ({percent_complete}% complete). "
f"Polling again in {polling_interval}s."
)
except (aiohttp.ContentTypeError, json.JSONDecodeError):
logger.debug(
f"Job {req_id} still running (202 status, non-JSON/malformed JSON body). Polling again in {polling_interval}s."
f"Job {req_id} still running (202 status, non-JSON/malformed JSON body). "
f"Polling again in {polling_interval}s."
)
await asyncio.sleep(polling_interval)
else:
text = await response.text()
logger.error(f"Error checking AlphaFold2-Multimer job status {req_id}: HTTP {response.status} - {text}")
return {"success": False, "error": f"Status check failed with HTTP {response.status}", "detail": text}
logger.error(
f"Error checking AlphaFold2-Multimer job status {req_id}: "
f"HTTP {response.status} - {text}"
)
return {
"success": False,
"error": f"Status check failed with HTTP {response.status}",
"detail": text,
}
except asyncio.TimeoutError:
logger.warning(f"Client-side timeout ({current_status_check_timeout}s) during status check for job {req_id}. Retrying poll after {polling_interval}s sleep.")
logger.warning(
f"Client-side timeout ({current_status_check_timeout}s) during status check for "
f"job {req_id}. Retrying poll after {polling_interval}s sleep."
)
await asyncio.sleep(polling_interval)
except aiohttp.ClientError as e:
logger.error(f"Client error polling job status for {req_id}: {e}. Retrying poll after {polling_interval}s.", exc_info=True)
logger.error(
f"Client error polling job status for {req_id}: {e}. "
f"Retrying poll after {polling_interval}s.",
exc_info=True,
)
await asyncio.sleep(polling_interval)
except Exception as e:
logger.error(f"Unexpected error polling job status {req_id}: {e}", exc_info=True)
logger.error(
f"Unexpected error polling job status {req_id}: {e}", exc_info=True
)
return {"success": False, "error": f"Unexpected polling error: {str(e)}"}

View file

@ -1,16 +1,15 @@
import os
import logging
import aiohttp
import json
import asyncio
from typing import Dict, List, Any, Optional, Union
from pathlib import Path
import logging
from typing import Any, Dict, List, Optional
import aiohttp
logger = logging.getLogger(__name__)
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/ipd/proteinmpnn/predict"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
async def call_proteinmpnn(
input_pdb: str,
api_key: str,
@ -20,7 +19,7 @@ async def call_proteinmpnn(
url: str = DEFAULT_URL,
status_url: str = DEFAULT_STATUS_URL,
polling_interval: int = 10,
timeout: int = 60
timeout: int = 60,
) -> Optional[Dict[str, Any]]:
"""
Call the NVIDIA NIM ProteinMPNN API.
@ -49,16 +48,13 @@ async def call_proteinmpnn(
"input_pdb": input_pdb,
"ca_only": ca_only,
"use_soluble_model": use_soluble_model,
"sampling_temp": sampling_temp
"sampling_temp": sampling_temp,
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(
url,
json=data,
headers=headers,
timeout=timeout
url, json=data, headers=headers, timeout=timeout
) as response:
if response.status == 200:
return await response.json()
@ -71,7 +67,7 @@ async def call_proteinmpnn(
headers=headers,
status_url=status_url,
polling_interval=polling_interval,
timeout=timeout
timeout=timeout,
)
else:
logger.error("No request ID in response headers")
@ -85,12 +81,13 @@ async def call_proteinmpnn(
logger.error(f"Error calling ProteinMPNN API: {e}")
return None
async def _poll_job_status(
req_id: str,
headers: Dict[str, str],
status_url: str,
polling_interval: int = 10,
timeout: int = 60
timeout: int = 60,
) -> Optional[Dict[str, Any]]:
"""
Poll the status endpoint until the job completes.
@ -109,18 +106,20 @@ async def _poll_job_status(
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{status_url}/{req_id}",
headers=headers,
timeout=timeout
f"{status_url}/{req_id}", headers=headers, timeout=timeout
) as response:
if response.status == 200:
logger.info(f"ProteinMPNN job {req_id} completed")
return await response.json()
elif response.status == 202:
logger.debug(f"ProteinMPNN job {req_id} still running, polling...")
logger.debug(
f"ProteinMPNN job {req_id} still running, polling..."
)
await asyncio.sleep(polling_interval)
else:
logger.error(f"Error checking ProteinMPNN job status: {response.status}")
logger.error(
f"Error checking ProteinMPNN job status: {response.status}"
)
text = await response.text()
logger.error(f"Response: {text}")
return None

View file

@ -1,16 +1,15 @@
import os
import logging
import aiohttp
import json
import asyncio
from typing import Dict, List, Any, Optional, Union
from pathlib import Path
import logging
from typing import Any, Dict, List, Optional
import aiohttp
logger = logging.getLogger(__name__)
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/ipd/rfdiffusion/generate"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
async def call_rfdiffusion(
input_pdb: str,
api_key: str,
@ -20,7 +19,7 @@ async def call_rfdiffusion(
url: str = DEFAULT_URL,
status_url: str = DEFAULT_STATUS_URL,
polling_interval: int = 10,
timeout: int = 60
timeout: int = 60,
) -> Optional[Dict[str, Any]]:
"""
Call the NVIDIA NIM RFDiffusion API.
@ -45,10 +44,7 @@ async def call_rfdiffusion(
"NVCF-POLL-SECONDS": "300",
}
data = {
"input_pdb": input_pdb,
"diffusion_steps": diffusion_steps
}
data = {"input_pdb": input_pdb, "diffusion_steps": diffusion_steps}
if contigs:
data["contigs"] = contigs
@ -58,10 +54,7 @@ async def call_rfdiffusion(
try:
async with aiohttp.ClientSession() as session:
async with session.post(
url,
json=data,
headers=headers,
timeout=timeout
url, json=data, headers=headers, timeout=timeout
) as response:
if response.status == 200:
return await response.json()
@ -74,7 +67,7 @@ async def call_rfdiffusion(
headers=headers,
status_url=status_url,
polling_interval=polling_interval,
timeout=timeout
timeout=timeout,
)
else:
logger.error("No request ID in response headers")
@ -88,12 +81,13 @@ async def call_rfdiffusion(
logger.error(f"Error calling RFDiffusion API: {e}")
return None
async def _poll_job_status(
req_id: str,
headers: Dict[str, str],
status_url: str,
polling_interval: int = 10,
timeout: int = 60
timeout: int = 60,
) -> Optional[Dict[str, Any]]:
"""
Poll the status endpoint until the job completes.
@ -112,18 +106,20 @@ async def _poll_job_status(
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{status_url}/{req_id}",
headers=headers,
timeout=timeout
f"{status_url}/{req_id}", headers=headers, timeout=timeout
) as response:
if response.status == 200:
logger.info(f"RFDiffusion job {req_id} completed")
return await response.json()
elif response.status == 202:
logger.debug(f"RFDiffusion job {req_id} still running, polling...")
logger.debug(
f"RFDiffusion job {req_id} still running, polling..."
)
await asyncio.sleep(polling_interval)
else:
logger.error(f"Error checking RFDiffusion job status: {response.status}")
logger.error(
f"Error checking RFDiffusion job status: {response.status}"
)
text = await response.text()
logger.error(f"Response: {text}")
return None

View file

@ -1,25 +1,25 @@
import logging
from typing import Dict
logger = logging.getLogger(__name__)
SYSTEM_PROMPT = """You are a specialized AI system for de novo protein design via a staged simulation loop. Your objective is to generate binder sequences that are structurally and functionally optimized to bind a given target protein.
SYSTEM_PROMPT = (
"You are a specialized AI system for de novo protein design via a staged simulation loop. "
"Your objective is to generate binder sequences that are structurally and functionally "
"optimized to bind a given target protein.\n\n"
"You will be guided through a multi-step pipeline:\n\n"
"1. Structure prediction (AlphaFold)\n"
"2. Binder backbone generation (RFdiffusion)\n"
"3. Sequence design (ProteinMPNN)\n"
"4. Structure evaluation (AlphaFold-Multimer)\n"
"5. Feedback loop\n\n"
"You must always:\n"
"- Respect the required file format for each tool (e.g., FASTA, PDB).\n"
"- Structure your outputs cleanly so they can be parsed and executed programmatically.\n"
"- Be explicit in all configuration steps (e.g., contigs, hotspots).\n"
"- Minimize ambiguity or verbosity; prefer concise and functional outputs.\n"
"- Reason step-by-step when appropriate."
)
You will be guided through a multi-step pipeline:
1. Structure prediction (AlphaFold)
2. Binder backbone generation (RFdiffusion)
3. Sequence design (ProteinMPNN)
4. Structure evaluation (AlphaFold-Multimer)
5. Feedback loop
You must always:
- Respect the required file format for each tool (e.g., FASTA, PDB).
- Structure your outputs cleanly so they can be parsed and executed programmatically.
- Be explicit in all configuration steps (e.g., contigs, hotspots).
- Minimize ambiguity or verbosity; prefer concise and functional outputs.
- Reason step-by-step when appropriate.
"""
def construct_user_prompt(state: dict) -> str:
"""
@ -42,19 +42,20 @@ def construct_user_prompt(state: dict) -> str:
"You must provide the 'sequence' argument."
)
elif internal_step == 1:
target_pdb_preview = state.get("target_pdb_preview", "PDB preview not available.")
chain_details = state.get("target_chain_details", {})
if chain_details:
chain_info_parts = []
for chain_id, details in chain_details.items():
min_r = details.get('min_residue', 'N/A')
max_r = details.get('max_residue', 'N/A')
l = details.get('length', 'N/A')
chain_info_parts.append(f"Chain {chain_id} (Residues: {min_r}-{max_r}, Length: {l} amino acids)")
min_r = details.get("min_residue", "N/A")
max_r = details.get("max_residue", "N/A")
length = details.get("length", "N/A")
chain_info_parts.append(
f"Chain {chain_id} (Residues: {min_r}-{max_r}, Length: {length} amino acids)"
)
chain_info_str = "\n- ".join(chain_info_parts)
if chain_info_str:
chain_info_str = "- " + chain_info_str
chain_info_str = "- " + chain_info_str
else:
chain_info_str = "Chain information not available or PDB not yet processed."
@ -62,19 +63,26 @@ def construct_user_prompt(state: dict) -> str:
f"The 3D structure of the target protein has been predicted.\n"
f"Target Protein Chain Details:\n{chain_info_str}\n\n"
"Your task is to design a binder backbone using the 'design_binder_backbone_rfdiffusion' tool. "
"You MUST specify 'contigs' for this tool. The 'contigs' string defines segments from the target PDB and segments for the new binder. "
"You MUST specify 'contigs' for this tool. The 'contigs' string defines segments from the target PDB "
"and segments for the new binder. "
"Examples:\n"
" - To use residues 10 through 100 of target chain A, and then diffuse a 60-residue binder: 'A10-100/0 60'\n"
" - To use chain B from residue 5 to 50, then diffuse a 30-residue binder, then use chain B from residue 60 to 100: 'B5-50/0 30 B60-100'\n"
"You MUST use the chain IDs and residue ranges exactly as provided in the 'Target Protein Chain Details' above. "
" - To use residues 10 through 100 of target chain A, and then diffuse a 60-residue binder: "
"'A10-100/0 60'\n"
" - To use chain B from residue 5 to 50, then diffuse a 30-residue binder, then use chain B "
"from residue 60 to 100: 'B5-50/0 30 B60-100'\n"
"You MUST use the chain IDs and residue ranges exactly as provided in the "
"'Target Protein Chain Details' above. "
"Do not invent chains or residue numbers outside these specified ranges for the target segments. "
"For binder segments (e.g., '/0 60'), specify the desired length (e.g., 60).\n"
"Optionally, provide 'hotspot_residues' (e.g., ['A50', 'A52']), ensuring these residues exist on the target as per the details above."
"Optionally, provide 'hotspot_residues' (e.g., ['A50', 'A52']), ensuring these residues exist "
"on the target as per the details above."
)
elif internal_step == 2:
binder_pdb_content = state.get("binder_backbone_pdb_content")
binder_pdb_preview = state.get("binder_pdb_preview", "Binder PDB preview not available.")
binder_pdb_preview = state.get(
"binder_pdb_preview", "Binder PDB preview not available."
)
binder_chain_info_str = "Binder chain information not available."
if binder_pdb_content:
@ -83,10 +91,12 @@ def construct_user_prompt(state: dict) -> str:
if binder_chain_details:
chain_info_parts = []
for cID, d_details in binder_chain_details.items():
min_r = d_details.get('min_residue', 'N/A')
max_r = d_details.get('max_residue', 'N/A')
l = d_details.get('length', 'N/A')
chain_info_parts.append(f"Chain {cID} (Residues: {min_r}-{max_r}, Length: {l} amino acids)")
min_r = d_details.get("min_residue", "N/A")
max_r = d_details.get("max_residue", "N/A")
length = d_details.get("length", "N/A")
chain_info_parts.append(
f"Chain {cID} (Residues: {min_r}-{max_r}, Length: {length} amino acids)"
)
binder_chain_info_str = "\n- ".join(chain_info_parts)
if binder_chain_info_str:
binder_chain_info_str = "- " + binder_chain_info_str
@ -99,7 +109,8 @@ def construct_user_prompt(state: dict) -> str:
user_prompt_str = (
f"A binder backbone has been generated. Binder PDB preview:\n{binder_pdb_preview}\n"
f"Binder chain information:\n{binder_chain_info_str}.\n"
"Now, design an optimal amino acid sequence for this binder backbone using the 'design_binder_sequence_proteinmpnn' tool. "
"Now, design an optimal amino acid sequence for this binder backbone using the "
"'design_binder_sequence_proteinmpnn' tool. "
"You can optionally specify 'sampling_temp' (e.g., [0.1, 0.2])."
)
elif internal_step == 3:
@ -110,27 +121,39 @@ def construct_user_prompt(state: dict) -> str:
if len(designed_binder_seq_data) == 1:
binder_display_str = designed_binder_seq_data[0]
else:
binder_display_str = f"{len(designed_binder_seq_data)} chains: " + \
", ".join([f"Chain {i+1} ({len(s)} aa): {s[:20]}..."
for i, s in enumerate(designed_binder_seq_data)])
binder_display_str = (
f"{len(designed_binder_seq_data)} chains: "
+ ", ".join(
[
f"Chain {i+1} ({len(s)} aa): {s[:20]}..."
for i, s in enumerate(designed_binder_seq_data)
]
)
)
elif isinstance(designed_binder_seq_data, str):
binder_display_str = designed_binder_seq_data
binder_display_str = designed_binder_seq_data
user_prompt_str = (
f"A binder has been designed. Designed binder sequence(s): {binder_display_str}. "
f"The original target sequence was: {target_sequence[:60]}...\n"
"Finally, evaluate the binding complex of the original target protein and ALL chains of this designed binder using the "
"'evaluate_binder_complex_alphafold2_multimer' tool. "
"Finally, evaluate the binding complex of the original target protein and ALL chains of this "
"designed binder using the 'evaluate_binder_complex_alphafold2_multimer' tool. "
"You can optionally specify 'relax_prediction' (default is True)."
)
else:
user_prompt_str = "The protein design workflow is complete. No further actions required by you for this item. If successful, the key metric was the pLDDT of the complex."
user_prompt_str = (
"The protein design workflow is complete. No further actions required by you for this item. "
"If successful, the key metric was the pLDDT of the complex."
)
if state.get("retry_count_this_internal_step", 0) > 0 and internal_step < 4:
retry_prefix = "Your previous attempt at this step was not successful. "
if state.get("previous_tool_error_message"):
retry_prefix += f"Details: {state['previous_tool_error_message']}. "
retry_prefix += "Please review the requirements and PDB details carefully and try again to correctly use the expected tool.\n\n"
retry_prefix += (
"Please review the requirements and PDB details carefully and try again to correctly use "
"the expected tool.\n\n"
)
user_prompt_str = retry_prefix + user_prompt_str
return user_prompt_str

View file

@ -0,0 +1,92 @@
# OpenAI-style function-calling schema for the AlphaFold2 target-folding tool
# (workflow step 0). The model must supply the target amino acid sequence.
PREDICT_TARGET_STRUCTURE_TOOL = dict(
    type="function",
    function=dict(
        name="predict_target_structure_alphafold2",
        description=(
            "Predicts the 3D structure of the target protein sequence using AlphaFold2."
        ),
        parameters=dict(
            type="object",
            properties=dict(
                # Single required argument: the target protein sequence.
                sequence=dict(
                    type="string",
                    description="Amino acid sequence of the target protein.",
                ),
            ),
            required=["sequence"],
        ),
    ),
)
# Schema for the RFDiffusion backbone-generation tool (workflow step 1).
# 'contigs' is mandatory; 'hotspot_residues' is optional.
DESIGN_BINDER_BACKBONE_TOOL = dict(
    type="function",
    function=dict(
        name="design_binder_backbone_rfdiffusion",
        description=(
            "Generates a novel protein binder backbone using RFDiffusion, "
            "conditioned on the target protein structure."
        ),
        parameters=dict(
            type="object",
            properties=dict(
                # Contig map describing target segments and binder lengths.
                contigs=dict(
                    type="string",
                    description="RFDiffusion contigs (e.g., 'A1-100/0 50-70').",
                ),
                # Optional target residues the binder should contact.
                hotspot_residues=dict(
                    type="array",
                    items=dict(type="string"),
                    description="Optional hotspot residues (e.g., ['A50', 'A52']).",
                ),
            ),
            required=["contigs"],
        ),
    ),
)
# Schema for the ProteinMPNN sequence-design tool (workflow step 2).
# All arguments are optional; 'sampling_temp' tunes sequence diversity.
DESIGN_BINDER_SEQUENCE_TOOL = dict(
    type="function",
    function=dict(
        name="design_binder_sequence_proteinmpnn",
        description=(
            "Designs an amino acid sequence for the generated binder backbone."
        ),
        parameters=dict(
            type="object",
            properties=dict(
                sampling_temp=dict(
                    type="array",
                    items=dict(type="number"),
                    description=(
                        "Sampling temperatures (e.g., [0.1, 0.2]). Default [0.1]."
                    ),
                ),
            ),
            required=[],
        ),
    ),
)
# Schema for the AlphaFold2-Multimer complex-evaluation tool (workflow step 3).
# 'relax_prediction' is the only (optional) argument.
EVALUATE_COMPLEX_TOOL = dict(
    type="function",
    function=dict(
        name="evaluate_binder_complex_alphafold2_multimer",
        description=(
            "Predicts the complex structure of target and designed binder, "
            "providing quality scores."
        ),
        parameters=dict(
            type="object",
            properties=dict(
                relax_prediction=dict(
                    type="boolean",
                    description="Whether to relax the prediction. Default True.",
                ),
            ),
            required=[],
        ),
    ),
)
# Registry of every tool schema exposed to the model, in pipeline order.
ALL_TOOLS_LIST = [
    PREDICT_TARGET_STRUCTURE_TOOL,  # step 0: fold the target
    DESIGN_BINDER_BACKBONE_TOOL,  # step 1: generate binder backbone
    DESIGN_BINDER_SEQUENCE_TOOL,  # step 2: design binder sequence
    EVALUATE_COMPLEX_TOOL,  # step 3: score the bound complex
]

View file

@ -1,19 +1,26 @@
import logging
import json
import re
from typing import Dict, Any, List, Tuple, Optional, Union
from pathlib import Path
from environments.hack0.protein_design_env.models.alphafold2 import call_alphafold2
from environments.hack0.protein_design_env.models.rfdiffusion import call_rfdiffusion
from environments.hack0.protein_design_env.models.proteinmpnn import call_proteinmpnn
from environments.hack0.protein_design_env.models.alphafold2_multimer import call_alphafold2_multimer
from environments.hack0.protein_design_env.utils.pdb_utils import get_pdb_chain_details
from typing import Dict, List, Optional, Tuple
from .models.alphafold2 import call_alphafold2
from .models.alphafold2_multimer import call_alphafold2_multimer
from .models.proteinmpnn import call_proteinmpnn
from .models.rfdiffusion import call_rfdiffusion
from .utils.pdb_utils import get_pdb_chain_details
logger = logging.getLogger(__name__)
class ToolExecutor:
def __init__(self, nim_api_key: str, api_timeout: int, polling_interval: int,
output_dir: Path, debug_protein_design_calls: bool):
def __init__(
self,
nim_api_key: str,
api_timeout: int,
polling_interval: int,
output_dir: Path,
debug_protein_design_calls: bool,
):
self.nim_api_key = nim_api_key
self.api_timeout = api_timeout
self.polling_interval = polling_interval
@ -21,18 +28,23 @@ class ToolExecutor:
self.debug_protein_design_calls = debug_protein_design_calls
self._debug_af2m_call_count = 0
def _validate_rfd_contigs(self, contigs_str: str, target_chain_details: Dict[str, Dict[str, int]]) -> Optional[str]:
def _validate_rfd_contigs(
self, contigs_str: str, target_chain_details: Dict[str, Dict[str, int]]
) -> Optional[str]:
"""
Validates the RFDiffusion contigs string against target PDB chain details.
Returns None if valid, or an error message string if invalid.
"""
if not contigs_str: return "Contigs string is empty."
if not contigs_str:
return "Contigs string is empty."
target_segment_pattern = re.compile(r"([A-Za-z0-9])(\d+)-(\d+)")
active_contig_parts = contigs_str.split('/') # Split by binder definition markers
active_contig_parts = contigs_str.split(
"/"
) # Split by binder definition markers
for part in active_contig_parts:
chain_segments_in_part = part.strip().split(' ')
chain_segments_in_part = part.strip().split(" ")
for segment_text in chain_segments_in_part:
segment_text = segment_text.strip()
if not segment_text or segment_text.isdigit():
@ -45,42 +57,61 @@ class ToolExecutor:
seg_end = int(seg_end_str)
if seg_chain_id not in target_chain_details:
return f"Chain '{seg_chain_id}' in contig segment '{segment_text}' not in target. Valid: {list(target_chain_details.keys())}."
return (
f"Chain '{seg_chain_id}' in contig segment '{segment_text}' not in target. "
f"Valid: {list(target_chain_details.keys())}."
)
chain_min = target_chain_details[seg_chain_id]["min_residue"]
chain_max = target_chain_details[seg_chain_id]["max_residue"]
if not (chain_min <= seg_start <= chain_max and chain_min <= seg_end <= chain_max and seg_start <= seg_end):
return (f"Residue range {seg_start}-{seg_end} for chain '{seg_chain_id}' in '{segment_text}' "
f"is invalid/out of bounds. Chain '{seg_chain_id}' actual range: {chain_min}-{chain_max}.")
if not (
chain_min <= seg_start <= chain_max
and chain_min <= seg_end <= chain_max
and seg_start <= seg_end
):
return (
f"Residue range {seg_start}-{seg_end} for chain '{seg_chain_id}' in '{segment_text}' "
f"is invalid/out of bounds. Chain '{seg_chain_id}' actual range: {chain_min}-{chain_max}."
)
return None
def _validate_rfd_hotspots(self, hotspot_list: List[str], target_chain_details: Dict[str, Dict[str, int]]) -> Optional[str]:
def _validate_rfd_hotspots(
self, hotspot_list: List[str], target_chain_details: Dict[str, Dict[str, int]]
) -> Optional[str]:
"""
Validates hotspot residues (e.g., ["A50", "B25"]) against target PDB chain details.
Returns None if valid, or an error message string if invalid.
"""
if not hotspot_list: return None
if not hotspot_list:
return None
hotspot_pattern = re.compile(r"([A-Za-z0-9])(\d+)")
for hotspot_str in hotspot_list:
match = hotspot_pattern.fullmatch(hotspot_str.strip()) # Add strip
match = hotspot_pattern.fullmatch(hotspot_str.strip()) # Add strip
if not match:
return f"Hotspot '{hotspot_str}' is not in expected format (e.g., 'A50')."
return (
f"Hotspot '{hotspot_str}' is not in expected format (e.g., 'A50')."
)
hs_chain_id, hs_res_num_str = match.groups()
hs_res_num = int(hs_res_num_str)
if hs_chain_id not in target_chain_details:
return f"Chain '{hs_chain_id}' for hotspot '{hotspot_str}' not in target. Valid: {list(target_chain_details.keys())}."
return (
f"Chain '{hs_chain_id}' for hotspot '{hotspot_str}' not in target. "
f"Valid: {list(target_chain_details.keys())}."
)
chain_min = target_chain_details[hs_chain_id]["min_residue"]
chain_max = target_chain_details[hs_chain_id]["max_residue"]
if not (chain_min <= hs_res_num <= chain_max):
return (f"Residue {hs_res_num} for hotspot '{hotspot_str}' (chain '{hs_chain_id}') "
f"out of bounds. Chain '{hs_chain_id}' actual range: {chain_min}-{chain_max}.")
return (
f"Residue {hs_res_num} for hotspot '{hotspot_str}' (chain '{hs_chain_id}') "
f"out of bounds. Chain '{hs_chain_id}' actual range: {chain_min}-{chain_max}."
)
return None
async def _run_nim_alphafold2(self, args: Dict, workflow_state: Dict) -> Dict:
@ -96,13 +127,18 @@ class ToolExecutor:
state_updates = {}
if self.debug_protein_design_calls:
logger.warning(f"DEBUG MODE: Bypassing AlphaFold2 API call for workflow {item_id}.")
logger.warning(
f"DEBUG MODE: Bypassing AlphaFold2 API call for workflow {item_id}."
)
module_dir = Path(__file__).parent
fixed_pdb_path = module_dir / "debug_target.pdb"
if not fixed_pdb_path.exists():
logger.error(f"Debug mode failed: {fixed_pdb_path} not found.")
tool_output = {"success": False, "error": f"Debug mode failed: Required file {fixed_pdb_path} not found."}
tool_output = {
"success": False,
"error": f"Debug mode failed: Required file {fixed_pdb_path} not found.",
}
return {"tool_output": tool_output, "state_updates": state_updates}
with open(fixed_pdb_path, "r") as f:
@ -115,7 +151,10 @@ class ToolExecutor:
state_updates["target_pdb_preview"] = pdb_preview
state_updates["target_structure_predicted"] = True
debug_pdb_path = self.output_dir / f"target_{item_id}_s{current_internal_step}_af2_DEBUG.pdb"
debug_pdb_path = (
self.output_dir
/ f"target_{item_id}_s{current_internal_step}_af2_DEBUG.pdb"
)
with open(debug_pdb_path, "w") as f:
f.write(pdb_content)
logger.info(f"DEBUG MODE: Copied fixed AlphaFold2 PDB to {debug_pdb_path}")
@ -124,25 +163,31 @@ class ToolExecutor:
"success": True,
"message": "DEBUG MODE: Used fixed PDB for AlphaFold2.",
"target_pdb_preview": pdb_preview,
"saved_pdb_path": str(debug_pdb_path)
"saved_pdb_path": str(debug_pdb_path),
}
return {"tool_output": tool_output, "state_updates": state_updates}
sequence_from_llm = args.get("sequence")
if not sequence_from_llm:
tool_output = {"success": False, "error": "Missing 'sequence' for AlphaFold2."}
tool_output = {
"success": False,
"error": "Missing 'sequence' for AlphaFold2.",
}
return {"tool_output": tool_output, "state_updates": state_updates}
actual_sequence_to_use = target_sequence_from_state
if sequence_from_llm != target_sequence_from_state:
logger.warning(
f"LLM provided sequence '{sequence_from_llm[:20]}...' for 'predict_target_structure_alphafold2'. "
f"However, this tool will use the canonical target sequence from the workflow state: '{target_sequence_from_state[:20]}...'"
f"However, this tool will use the canonical target sequence from the workflow state: "
f"'{target_sequence_from_state[:20]}...'"
)
api_result = await call_alphafold2(
sequence=actual_sequence_to_use, api_key=self.nim_api_key,
timeout=self.api_timeout, polling_interval=self.polling_interval
sequence=actual_sequence_to_use,
api_key=self.nim_api_key,
timeout=self.api_timeout,
polling_interval=self.polling_interval,
)
if api_result and isinstance(api_result, list) and api_result[0]:
pdb_content = api_result[0]
@ -153,20 +198,34 @@ class ToolExecutor:
state_updates["target_pdb_preview"] = pdb_preview
state_updates["target_structure_predicted"] = True
pdb_path = self.output_dir / f"target_{item_id}_s{current_internal_step}_af2.pdb"
with open(pdb_path, "w") as f: f.write(pdb_content)
logger.info(f"Workflow {item_id}: AlphaFold2 PDB saved to {pdb_path}. Chain details: {chain_details}")
pdb_path = (
self.output_dir / f"target_{item_id}_s{current_internal_step}_af2.pdb"
)
with open(pdb_path, "w") as f:
f.write(pdb_content)
logger.info(
f"Workflow {item_id}: AlphaFold2 PDB saved to {pdb_path}. "
f"Chain details: {chain_details}"
)
tool_output = {"success": True, "message": "AlphaFold2 prediction complete.", "target_pdb_preview": pdb_preview, "saved_pdb_path": str(pdb_path)}
tool_output = {
"success": True,
"message": "AlphaFold2 prediction complete.",
"target_pdb_preview": pdb_preview,
"saved_pdb_path": str(pdb_path),
}
else:
error_detail = api_result.get("error", "AlphaFold2 prediction failed.") if isinstance(api_result, dict) else "AlphaFold2 prediction failed."
error_detail = (
api_result.get("error", "AlphaFold2 prediction failed.")
if isinstance(api_result, dict)
else "AlphaFold2 prediction failed."
)
logger.error(f"Workflow {item_id}: AlphaFold2 call failed: {error_detail}")
tool_output = {"success": False, "error": error_detail}
state_updates["target_structure_predicted"] = False
return {"tool_output": tool_output, "state_updates": state_updates}
async def _run_nim_rfdiffusion(self, args: Dict, workflow_state: Dict) -> Dict:
"""
Runs RFDiffusion for binder backbone design. Returns structured output with
@ -182,30 +241,55 @@ class ToolExecutor:
contigs_str_from_llm = args.get("contigs")
if not target_pdb_content:
tool_output = {"success": False, "error": "Target PDB not found in state for RFDiffusion."}
tool_output = {
"success": False,
"error": "Target PDB not found in state for RFDiffusion.",
}
return {"tool_output": tool_output, "state_updates": state_updates}
if not contigs_str_from_llm:
tool_output = {"success": False, "error": "Missing 'contigs' for RFDiffusion."}
tool_output = {
"success": False,
"error": "Missing 'contigs' for RFDiffusion.",
}
return {"tool_output": tool_output, "state_updates": state_updates}
validation_error = self._validate_rfd_contigs(contigs_str_from_llm, target_chain_details)
validation_error = self._validate_rfd_contigs(
contigs_str_from_llm, target_chain_details
)
if validation_error:
logger.warning(f"RFDiffusion contigs validation failed for item {item_id}: {validation_error}. Contigs: '{contigs_str_from_llm}'")
tool_output = {"success": False, "error": f"Invalid contigs: {validation_error}"}
logger.warning(
f"RFDiffusion contigs validation failed for item {item_id}: {validation_error}. "
f"Contigs: '{contigs_str_from_llm}'"
)
tool_output = {
"success": False,
"error": f"Invalid contigs: {validation_error}",
}
return {"tool_output": tool_output, "state_updates": state_updates}
hotspot_residues = args.get("hotspot_residues")
if hotspot_residues:
hotspot_validation_error = self._validate_rfd_hotspots(hotspot_residues, target_chain_details)
hotspot_validation_error = self._validate_rfd_hotspots(
hotspot_residues, target_chain_details
)
if hotspot_validation_error:
logger.warning(f"RFDiffusion hotspot validation failed for item {item_id}: {hotspot_validation_error}. Hotspots: {hotspot_residues}")
tool_output = {"success": False, "error": f"Invalid hotspots: {hotspot_validation_error}"}
logger.warning(
f"RFDiffusion hotspot validation failed for item {item_id}: {hotspot_validation_error}. "
f"Hotspots: {hotspot_residues}"
)
tool_output = {
"success": False,
"error": f"Invalid hotspots: {hotspot_validation_error}",
}
return {"tool_output": tool_output, "state_updates": state_updates}
api_result = await call_rfdiffusion(
input_pdb=target_pdb_content, api_key=self.nim_api_key,
contigs=contigs_str_from_llm, hotspot_res=hotspot_residues,
timeout=self.api_timeout, polling_interval=self.polling_interval
input_pdb=target_pdb_content,
api_key=self.nim_api_key,
contigs=contigs_str_from_llm,
hotspot_res=hotspot_residues,
timeout=self.api_timeout,
polling_interval=self.polling_interval,
)
if api_result and "output_pdb" in api_result:
@ -217,20 +301,34 @@ class ToolExecutor:
state_updates["binder_pdb_preview"] = binder_pdb_preview
state_updates["binder_backbone_designed"] = True
pdb_path = self.output_dir / f"binder_backbone_{item_id}_s{current_internal_step}_rfd.pdb"
with open(pdb_path, "w") as f: f.write(binder_pdb)
pdb_path = (
self.output_dir
/ f"binder_backbone_{item_id}_s{current_internal_step}_rfd.pdb"
)
with open(pdb_path, "w") as f:
f.write(binder_pdb)
logger.info(f"Workflow {item_id}: RFDiffusion PDB saved to {pdb_path}")
tool_output = {"success": True, "message": "RFDiffusion complete.", "binder_backbone_pdb_preview": binder_pdb_preview, "saved_pdb_path": str(pdb_path)}
tool_output = {
"success": True,
"message": "RFDiffusion complete.",
"binder_backbone_pdb_preview": binder_pdb_preview,
"saved_pdb_path": str(pdb_path),
}
else:
error_detail = api_result.get("error", "RFDiffusion failed.") if isinstance(api_result, dict) else "RFDiffusion failed."
logger.error(f"Workflow {item_id}: RFDiffusion call failed: {error_detail}. API Result: {api_result}")
error_detail = (
api_result.get("error", "RFDiffusion failed.")
if isinstance(api_result, dict)
else "RFDiffusion failed."
)
logger.error(
f"Workflow {item_id}: RFDiffusion call failed: {error_detail}. API Result: {api_result}"
)
tool_output = {"success": False, "error": error_detail}
state_updates["binder_backbone_designed"] = False
return {"tool_output": tool_output, "state_updates": state_updates}
async def _run_nim_proteinmpnn(self, args: Dict, workflow_state: Dict) -> Dict:
"""
Runs ProteinMPNN for binder sequence design. Returns structured output with
@ -244,20 +342,33 @@ class ToolExecutor:
state_updates = {}
if not binder_pdb:
tool_output = {"success": False, "error": "Binder backbone PDB not found for ProteinMPNN."}
tool_output = {
"success": False,
"error": "Binder backbone PDB not found for ProteinMPNN.",
}
return {"tool_output": tool_output, "state_updates": state_updates}
sampling_temp_list = args.get("sampling_temp", [0.1])
api_result = await call_proteinmpnn(
input_pdb=binder_pdb, api_key=self.nim_api_key,
input_pdb=binder_pdb,
api_key=self.nim_api_key,
sampling_temp=sampling_temp_list,
timeout=self.api_timeout, polling_interval=self.polling_interval
timeout=self.api_timeout,
polling_interval=self.polling_interval,
)
if not (api_result and "mfasta" in api_result):
error_detail = api_result.get("error", "ProteinMPNN call failed or no mfasta in result.") if isinstance(api_result, dict) else "PMPNN call failed"
logger.error(f"Workflow {item_id}: ProteinMPNN call failed: {error_detail}. API Result: {api_result}")
error_detail = (
api_result.get(
"error", "ProteinMPNN call failed or no mfasta in result."
)
if isinstance(api_result, dict)
else "PMPNN call failed"
)
logger.error(
f"Workflow {item_id}: ProteinMPNN call failed: {error_detail}. API Result: {api_result}"
)
tool_output = {"success": False, "error": error_detail}
state_updates["binder_sequence_designed"] = False
return {"tool_output": tool_output, "state_updates": state_updates}
@ -268,12 +379,15 @@ class ToolExecutor:
current_sequence_parts: List[str] = []
for line_content in fasta_content.splitlines():
line = line_content.strip()
if not line: continue
if not line:
continue
if line.startswith(">"):
if current_header and current_sequence_parts:
full_sequence_line = "".join(current_sequence_parts)
score_match = re.search(r"global_score=([-\d.]+)", current_header)
global_score = float(score_match.group(1)) if score_match else -float('inf')
global_score = (
float(score_match.group(1)) if score_match else -float("inf")
)
entries.append((global_score, current_header, full_sequence_line))
current_header = line
current_sequence_parts = []
@ -282,7 +396,7 @@ class ToolExecutor:
if current_header and current_sequence_parts:
full_sequence_line = "".join(current_sequence_parts)
score_match = re.search(r"global_score=([-\d.]+)", current_header)
global_score = float(score_match.group(1)) if score_match else -float('inf')
global_score = float(score_match.group(1)) if score_match else -float("inf")
entries.append((global_score, current_header, full_sequence_line))
if not entries:
@ -292,36 +406,59 @@ class ToolExecutor:
entries.sort(key=lambda x: x[0], reverse=True)
best_global_score, best_header, best_full_sequence_line = entries[0]
logger.info(f"Workflow {item_id}: Best PMPNN sequence chosen (global_score={best_global_score:.4f}) from header: '{best_header}' -> Seq line: '{best_full_sequence_line}'")
logger.info(
f"Workflow {item_id}: Best PMPNN sequence chosen (global_score={best_global_score:.4f}) "
f"from header: '{best_header}' -> Seq line: '{best_full_sequence_line}'"
)
parsed_binder_chains = [s.strip() for s in best_full_sequence_line.split('/') if s.strip()]
parsed_binder_chains = [
s.strip() for s in best_full_sequence_line.split("/") if s.strip()
]
if not parsed_binder_chains or not all(s and s.isalpha() and s.isupper() for s in parsed_binder_chains):
tool_output = {"success": False, "error": f"Invalid binder chains from PMPNN after parsing '{best_full_sequence_line}'. Parsed: {parsed_binder_chains}"}
if not parsed_binder_chains or not all(
s and s.isalpha() and s.isupper() for s in parsed_binder_chains
):
tool_output = {
"success": False,
"error": (
f"Invalid binder chains from PMPNN after parsing '{best_full_sequence_line}'. "
f"Parsed: {parsed_binder_chains}"
),
}
state_updates["binder_sequence_designed"] = False
return {"tool_output": tool_output, "state_updates": state_updates}
state_updates["designed_binder_sequence"] = parsed_binder_chains
state_updates["binder_sequence_designed"] = True
fasta_path = self.output_dir / f"binder_sequence_{item_id}_s{current_internal_step}_pmpnn.fasta"
with open(fasta_path, "w") as f: f.write(fasta_content)
logger.info(f"Workflow {item_id}: ProteinMPNN FASTA saved to {fasta_path}. Selected binder chains: {parsed_binder_chains}")
fasta_path = (
self.output_dir
/ f"binder_sequence_{item_id}_s{current_internal_step}_pmpnn.fasta"
)
with open(fasta_path, "w") as f:
f.write(fasta_content)
logger.info(
f"Workflow {item_id}: ProteinMPNN FASTA saved to {fasta_path}. "
f"Selected binder chains: {parsed_binder_chains}"
)
preview = parsed_binder_chains[0][:60] + "..." if parsed_binder_chains else "N/A"
preview = (
parsed_binder_chains[0][:60] + "..." if parsed_binder_chains else "N/A"
)
if len(parsed_binder_chains) > 1:
preview += f" (+ {len(parsed_binder_chains)-1} more chain(s))"
tool_output = {
"success": True,
"message": f"ProteinMPNN complete. Selected best (global_score={best_global_score:.4f}).",
"message": (
f"ProteinMPNN complete. Selected best (global_score={best_global_score:.4f})."
),
"designed_binder_sequence_list": parsed_binder_chains,
"designed_binder_sequence_preview": preview,
"saved_fasta_path": str(fasta_path)
"saved_fasta_path": str(fasta_path),
}
return {"tool_output": tool_output, "state_updates": state_updates}
async def _run_nim_af2_multimer(self, args: Dict, workflow_state: Dict) -> Dict:
item_id = workflow_state["item_id"]
current_internal_step = workflow_state["current_internal_step"]
@ -331,16 +468,31 @@ class ToolExecutor:
tool_output = {}
state_updates = {}
if not target_seq or not designed_binder_chains_list or not isinstance(designed_binder_chains_list, list):
tool_output = {"success": False, "error": "Missing or invalid sequences for AF2-Multimer."}
if (
not target_seq
or not designed_binder_chains_list
or not isinstance(designed_binder_chains_list, list)
):
tool_output = {
"success": False,
"error": "Missing or invalid sequences for AF2-Multimer.",
}
return {"tool_output": tool_output, "state_updates": state_updates}
all_input_sequences_for_multimer = [target_seq] + designed_binder_chains_list
for i, seq_to_validate in enumerate(all_input_sequences_for_multimer):
if not (seq_to_validate and isinstance(seq_to_validate, str) and seq_to_validate.isalpha() and seq_to_validate.isupper()):
error_msg = (f"Sequence {i+1} (part of target/binder complex) is invalid: "
f"'{str(seq_to_validate)[:30]}...'. Contains non-alpha/lowercase, is empty, or not a string.")
if not (
seq_to_validate
and isinstance(seq_to_validate, str)
and seq_to_validate.isalpha()
and seq_to_validate.isupper()
):
error_msg = (
f"Sequence {i+1} (part of target/binder complex) is invalid: "
f"'{str(seq_to_validate)[:30]}...'. "
f"Contains non-alpha/lowercase, is empty, or not a string."
)
logger.error(f"Workflow {item_id}: {error_msg}")
tool_output = {"success": False, "error": error_msg}
return {"tool_output": tool_output, "state_updates": state_updates}
@ -350,28 +502,41 @@ class ToolExecutor:
if self.debug_protein_design_calls:
self._debug_af2m_call_count += 1
mock_plddt = 87.5 if self._debug_af2m_call_count % 2 == 1 else 45.2
success_message = f"DEBUG MODE: Returning {'high' if mock_plddt > 50 else 'low'}-quality mock results (call #{self._debug_af2m_call_count})"
success_message = (
f"DEBUG MODE: Returning {'high' if mock_plddt > 50 else 'low'}-quality "
f"mock results (call #{self._debug_af2m_call_count})"
)
debug_pdb_filename = f"complex_{item_id}_s{current_internal_step}_af2m_DEBUG_pLDDT{mock_plddt:.2f}.pdb"
debug_pdb_path = self.output_dir / debug_pdb_filename
try:
with open(debug_pdb_path, "w") as f:
f.write(f"REMARK DEBUG PDB FILE for complex. Predicted pLDDT {mock_plddt}\n")
logger.info(f"DEBUG MODE: Saved mock AF2-Multimer PDB to {debug_pdb_path}")
f.write(
f"REMARK DEBUG PDB FILE for complex. Predicted pLDDT {mock_plddt}\n"
)
logger.info(
f"DEBUG MODE: Saved mock AF2-Multimer PDB to {debug_pdb_path}"
)
state_updates["complex_pdb_content_path"] = str(debug_pdb_path)
except IOError as e:
logger.error(f"DEBUG MODE: Failed to write mock PDB {debug_pdb_path}: {e}")
logger.error(
f"DEBUG MODE: Failed to write mock PDB {debug_pdb_path}: {e}"
)
# If saving fails, don't set the path, but can still proceed with mock pLDDT
state_updates["complex_pdb_content_path"] = None
state_updates["af2_multimer_plddt"] = mock_plddt
state_updates["complex_evaluated"] = True
tool_output = {
"success": True, "message": f"{success_message}. Mock pLDDT: {mock_plddt:.2f}",
"success": True,
"message": f"{success_message}. Mock pLDDT: {mock_plddt:.2f}",
"plddt": mock_plddt,
"complex_file_path": str(debug_pdb_path) if state_updates["complex_pdb_content_path"] else None
"complex_file_path": (
str(debug_pdb_path)
if state_updates["complex_pdb_content_path"]
else None
),
}
return {"tool_output": tool_output, "state_updates": state_updates}
@ -380,36 +545,57 @@ class ToolExecutor:
api_key=self.nim_api_key,
relax_prediction=relax,
timeout=self.api_timeout,
polling_interval=self.polling_interval
polling_interval=self.polling_interval,
)
if api_result is None or (isinstance(api_result, dict) and api_result.get("success") is False):
if api_result is None or (
isinstance(api_result, dict) and api_result.get("success") is False
):
error_detail = "AF2-Multimer call failed or returned None."
if isinstance(api_result, dict):
error_detail = api_result.get("error", "AF2-Multimer call failed with unspecified error.")
error_detail = api_result.get(
"error", "AF2-Multimer call failed with unspecified error."
)
detail_info = api_result.get("detail", "")
if detail_info: error_detail += f" Details: {detail_info}"
if detail_info:
error_detail += f" Details: {detail_info}"
logger.error(f"Workflow {item_id}: AF2-Multimer call failed: {error_detail}. API Result: {api_result}")
logger.error(
f"Workflow {item_id}: AF2-Multimer call failed: {error_detail}. "
f"API Result: {api_result}"
)
tool_output = {"success": False, "error": error_detail}
state_updates["complex_evaluated"] = False
return {"tool_output": tool_output, "state_updates": state_updates}
all_structures_info = api_result.get("structures")
if not all_structures_info or not isinstance(all_structures_info, list):
message = api_result.get("message", "No structures returned from AF2-Multimer process.")
message = api_result.get(
"message", "No structures returned from AF2-Multimer process."
)
logger.warning(f"Workflow {item_id}: {message}. API Result: {api_result}")
if not all_structures_info and isinstance(all_structures_info, list):
tool_output = {"success": True, "message": "AF2-Multimer ran, but no PDB structures were produced by the API.", "plddt": 0.0, "complex_file_path": None}
state_updates["af2_multimer_plddt"] = 0.0
state_updates["complex_evaluated"] = True
state_updates["complex_pdb_content_path"] = None
tool_output = {
"success": True,
"message": (
"AF2-Multimer ran, but no PDB structures were produced by the API."
),
"plddt": 0.0,
"complex_file_path": None,
}
state_updates["af2_multimer_plddt"] = 0.0
state_updates["complex_evaluated"] = True
state_updates["complex_pdb_content_path"] = None
else:
tool_output = {"success": False, "error": "AF2-Multimer returned unexpected data or no structures."}
tool_output = {
"success": False,
"error": (
"AF2-Multimer returned unexpected data or no structures."
),
}
state_updates["complex_evaluated"] = False
return {"tool_output": tool_output, "state_updates": state_updates}
best_structure_info = None
highest_plddt = -1.0
@ -420,8 +606,14 @@ class ToolExecutor:
best_structure_info = struct_info
if best_structure_info is None:
logger.error(f"Workflow {item_id}: No valid structure with pLDDT found in AF2-Multimer results, though structures were present.")
tool_output = {"success": False, "error": "No valid structure with pLDDT in AF2-Multimer results."}
logger.error(
f"Workflow {item_id}: No valid structure with pLDDT found in AF2-Multimer results, "
f"though structures were present."
)
tool_output = {
"success": False,
"error": ("No valid structure with pLDDT in AF2-Multimer results."),
}
state_updates["complex_evaluated"] = False
return {"tool_output": tool_output, "state_updates": state_updates}
@ -430,53 +622,81 @@ class ToolExecutor:
best_model_idx = best_structure_info.get("model_index", "NA")
if not best_pdb_content:
logger.error(f"Workflow {item_id}: Best AF2-Multimer structure (Model {best_model_idx}, pLDDT {best_plddt:.2f}) found, but PDB content is missing.")
tool_output = {"success": False, "error": f"Best model (pLDDT {best_plddt:.2f}) has no PDB content."}
logger.error(
f"Workflow {item_id}: Best AF2-Multimer structure (Model {best_model_idx}, "
f"pLDDT {best_plddt:.2f}) found, but PDB content is missing."
)
tool_output = {
"success": False,
"error": f"Best model (pLDDT {best_plddt:.2f}) has no PDB content.",
}
state_updates["complex_evaluated"] = False
state_updates["af2_multimer_plddt"] = best_plddt
return {"tool_output": tool_output, "state_updates": state_updates}
complex_pdb_filename = f"complex_{item_id}_s{current_internal_step}_af2m_model{best_model_idx}_pLDDT{best_plddt:.2f}.pdb"
complex_pdb_filename = (
f"complex_{item_id}_s{current_internal_step}_af2m_model{best_model_idx}_"
f"pLDDT{best_plddt:.2f}.pdb"
)
complex_pdb_path = self.output_dir / complex_pdb_filename
try:
with open(complex_pdb_path, "w", encoding='utf-8') as f:
with open(complex_pdb_path, "w", encoding="utf-8") as f:
f.write(best_pdb_content)
logger.info(f"Workflow {item_id}: AlphaFold2-Multimer complete. Saved best model (Index {best_model_idx}) with pLDDT: {best_plddt:.2f} from {len(all_structures_info)} models to {complex_pdb_path}")
logger.info(
f"Workflow {item_id}: AlphaFold2-Multimer complete. Saved best model (Index {best_model_idx}) "
f"with pLDDT: {best_plddt:.2f} from {len(all_structures_info)} models to {complex_pdb_path}"
)
state_updates["complex_pdb_content_path"] = str(complex_pdb_path)
state_updates["af2_multimer_plddt"] = best_plddt
state_updates["complex_evaluated"] = True
complex_quality_message = f"AlphaFold2-Multimer evaluation complete. Selected best model (Index {best_model_idx}) with pLDDT: {best_plddt:.2f}"
complex_quality_message = (
f"AlphaFold2-Multimer evaluation complete. Selected best model (Index {best_model_idx}) "
f"with pLDDT: {best_plddt:.2f}"
)
tool_output = {
"success": True,
"message": complex_quality_message,
"plddt": best_plddt,
"complex_file_path": str(complex_pdb_path),
"selected_model_index": best_model_idx
"selected_model_index": best_model_idx,
}
except IOError as e:
logger.error(f"Workflow {item_id}: Failed to save best AF2-Multimer PDB (Model {best_model_idx}, pLDDT {best_plddt:.2f}) to {complex_pdb_path}: {e}")
tool_output = {"success": False, "error": f"Failed to save best complex PDB: {e}"}
logger.error(
f"Workflow {item_id}: Failed to save best AF2-Multimer PDB (Model {best_model_idx}, "
f"pLDDT {best_plddt:.2f}) to {complex_pdb_path}: {e}"
)
tool_output = {
"success": False,
"error": f"Failed to save best complex PDB: {e}",
}
state_updates["af2_multimer_plddt"] = best_plddt
state_updates["complex_pdb_content_path"] = None
state_updates["complex_evaluated"] = True
return {"tool_output": tool_output, "state_updates": state_updates}
async def dispatch_tool_call(self, tool_name: str, args: Dict, workflow_state: Dict) -> Dict:
async def dispatch_tool_call(
self, tool_name: str, args: Dict, workflow_state: Dict
) -> Dict:
"""Main dispatch method for executing tools."""
item_id = workflow_state["item_id"]
internal_step = workflow_state["current_internal_step"]
logger.info(f"ToolExecutor: Dispatching tool '{tool_name}' for workflow {item_id}, Step {internal_step} with args: {args}")
logger.info(
f"ToolExecutor: Dispatching tool '{tool_name}' for workflow {item_id}, "
f"Step {internal_step} with args: {args}"
)
if not self.nim_api_key:
return {
"tool_output": {"success": False, "error": "NIM API key not configured in ToolExecutor."},
"state_updates": {}
"tool_output": {
"success": False,
"error": "NIM API key not configured in ToolExecutor.",
},
"state_updates": {},
}
if tool_name == "predict_target_structure_alphafold2":
@ -488,8 +708,13 @@ class ToolExecutor:
elif tool_name == "evaluate_binder_complex_alphafold2_multimer":
return await self._run_nim_af2_multimer(args, workflow_state)
else:
logger.error(f"ToolExecutor: Unknown tool name '{tool_name}' for workflow {item_id}")
logger.error(
f"ToolExecutor: Unknown tool name '{tool_name}' for workflow {item_id}"
)
return {
"tool_output": {"success": False, "error": f"Unknown tool name: {tool_name}"},
"state_updates": {}
"tool_output": {
"success": False,
"error": f"Unknown tool name: {tool_name}",
},
"state_updates": {},
}

View file

@ -2,4 +2,4 @@
from .pdb_utils import get_pdb_chain_details
__all__ = ["get_pdb_chain_details"]
__all__ = ["get_pdb_chain_details"]

View file

@ -1,13 +1,13 @@
import os
import logging
import yaml
from pathlib import Path
import os
from typing import Optional
from dotenv import load_dotenv
load_dotenv()
logger = logging.getLogger(__name__)
def load_api_key() -> Optional[str]:
"""
Load the NVIDIA NIM API key from environment variables.
@ -17,8 +17,10 @@ def load_api_key() -> Optional[str]:
"""
api_key = os.environ.get("NVIDIA_NIM_API_KEY")
if not api_key:
logger.error("NVIDIA_NIM_API_KEY not found in environment variables. "
"Please set it in your .env file.")
logger.error(
"NVIDIA_NIM_API_KEY not found in environment variables. "
"Please set it in your .env file."
)
return None
return api_key

View file

@ -1,9 +1,12 @@
import logging
from typing import Dict, Tuple, List, Set, Union
from typing import Dict, Set, Tuple, Union
logger = logging.getLogger(__name__)
def get_pdb_chain_details(pdb_content: str, preview_lines: int = 10) -> Tuple[Dict[str, Dict[str, int]], str]:
def get_pdb_chain_details(
pdb_content: str, preview_lines: int = 10
) -> Tuple[Dict[str, Dict[str, int]], str]:
"""
Parses PDB content to extract detailed information for each chain.
@ -27,7 +30,8 @@ def get_pdb_chain_details(pdb_content: str, preview_lines: int = 10) -> Tuple[Di
if line.startswith("ATOM"):
atom_lines.append(line)
chain_id = line[21:22].strip()
if not chain_id: chain_id = " "
if not chain_id:
chain_id = " "
atom_name = line[12:16].strip()
try:
residue_num = int(line[22:26].strip())
@ -39,7 +43,11 @@ def get_pdb_chain_details(pdb_content: str, preview_lines: int = 10) -> Tuple[Di
except ValueError:
logger.warning(f"Could not parse residue number from PDB line: {line}")
continue
elif line.startswith("HEADER") or line.startswith("TITLE") or line.startswith("COMPND"):
elif (
line.startswith("HEADER")
or line.startswith("TITLE")
or line.startswith("COMPND")
):
header_lines.append(line)
chain_details: Dict[str, Dict[str, int]] = {}
@ -51,14 +59,16 @@ def get_pdb_chain_details(pdb_content: str, preview_lines: int = 10) -> Tuple[Di
chain_details[chain_id] = {
"min_residue": min_res,
"max_residue": max_res,
"length": length
"length": length,
}
else:
logger.warning(f"Chain {chain_id} had no parseable ATOM residue numbers.")
preview_str_parts = header_lines[:min(len(header_lines), preview_lines // 2)]
preview_str_parts = header_lines[: min(len(header_lines), preview_lines // 2)]
remaining_preview_lines = preview_lines - len(preview_str_parts)
preview_str_parts.extend(atom_lines[:min(len(atom_lines), remaining_preview_lines)])
preview_str_parts.extend(
atom_lines[: min(len(atom_lines), remaining_preview_lines)]
)
pdb_preview = "\n".join(preview_str_parts)
if len(pdb_content.splitlines()) > preview_lines:
pdb_preview += "\n..."

View file

@ -1 +0,0 @@
"""Protein design model API modules."""

View file

@ -1,67 +0,0 @@
# OpenAI-style function schema for the AlphaFold2 target-structure tool.
# The model supplies only the target's amino-acid sequence.
PREDICT_TARGET_STRUCTURE_TOOL = {
    "type": "function",
    "function": {
        "name": "predict_target_structure_alphafold2",
        "description": (
            "Predicts the 3D structure of the target protein sequence using AlphaFold2."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "sequence": {
                    "type": "string",
                    "description": "Amino acid sequence of the target protein.",
                },
            },
            "required": ["sequence"],
        },
    },
}
# Function schema for the RFDiffusion binder-backbone generation tool.
DESIGN_BINDER_BACKBONE_TOOL = {
    "type": "function",
    "function": {
        "name": "design_binder_backbone_rfdiffusion",
        "description": (
            "Generates a novel protein binder backbone using RFDiffusion, "
            "conditioned on the target protein structure."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "contigs": {
                    "type": "string",
                    "description": "RFDiffusion contigs (e.g., 'A1-100/0 50-70').",
                },
                "hotspot_residues": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Optional hotspot residues (e.g., ['A50', 'A52']).",
                },
            },
            # Only the contig specification is mandatory; hotspots are optional.
            "required": ["contigs"],
        },
    },
}
# Function schema for the ProteinMPNN sequence-design tool.
# All parameters are optional; the backbone comes from workflow state, not the call.
DESIGN_BINDER_SEQUENCE_TOOL = {
    "type": "function",
    "function": {
        "name": "design_binder_sequence_proteinmpnn",
        "description": (
            "Designs an amino acid sequence for the generated binder backbone."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "sampling_temp": {
                    "type": "array",
                    "items": {"type": "number"},
                    "description": (
                        "Sampling temperatures (e.g., [0.1, 0.2]). Default [0.1]."
                    ),
                },
            },
            "required": [],
        },
    },
}
# Function schema for the AlphaFold2-Multimer complex-evaluation tool.
# All parameters are optional; target and binder are taken from workflow state.
EVALUATE_COMPLEX_TOOL = {
    "type": "function",
    "function": {
        "name": "evaluate_binder_complex_alphafold2_multimer",
        "description": (
            "Predicts the complex structure of target and designed binder, "
            "providing quality scores."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "relax_prediction": {
                    "type": "boolean",
                    "description": "Whether to relax the prediction. Default True.",
                },
            },
            "required": [],
        },
    },
}
# All tool schemas exposed to the model, listed in pipeline order:
# target structure -> binder backbone -> binder sequence -> complex evaluation.
ALL_TOOLS_LIST = [
    PREDICT_TARGET_STRUCTURE_TOOL,
    DESIGN_BINDER_BACKBONE_TOOL,
    DESIGN_BINDER_SEQUENCE_TOOL,
    EVALUATE_COMPLEX_TOOL,
]