mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
linting
This commit is contained in:
parent
13a70e09ab
commit
54967ecae9
19 changed files with 1337 additions and 531 deletions
|
|
@ -2040,6 +2040,65 @@ python test_stl_env.py
|
|||
|
||||
---
|
||||
|
||||
### 23. Protein Design Environment (`protein_design/`)
|
||||
|
||||
**Contributors**: hallerite, promachina
|
||||
**PR**: [#70](https://github.com/NousResearch/atropos/pull/70)
|
||||
**Integration Status**: ✅ Integrated
|
||||
|
||||
**Description**: A comprehensive reinforcement learning environment for de novo protein design through a staged simulation loop. This environment enables AI systems to learn the complete protein design workflow from target structure prediction to binder evaluation, using state-of-the-art protein modeling tools.
|
||||
|
||||
**Core Features**:
|
||||
|
||||
**Multi-Stage Protein Design Pipeline**:
|
||||
- **AlphaFold2 Structure Prediction**: Predicts 3D structure of target proteins from amino acid sequences
|
||||
- **RFDiffusion Backbone Generation**: Generates novel protein binder backbones conditioned on target structures
|
||||
- **ProteinMPNN Sequence Design**: Designs optimal amino acid sequences for generated backbones
|
||||
- **AlphaFold2-Multimer Evaluation**: Evaluates binding complex quality with pLDDT scoring
|
||||
|
||||
**Advanced Workflow Management**:
|
||||
- **State-Based Progression**: Tracks workflow state through 4 distinct internal steps
|
||||
- **Retry Logic**: Configurable retry mechanisms for failed tool executions
|
||||
- **Validation Systems**: Comprehensive input validation for contigs, hotspots, and sequences
|
||||
- **Error Handling**: Robust error recovery and detailed logging
|
||||
|
||||
**NVIDIA NIM Integration**:
|
||||
- **API-Based Execution**: Leverages NVIDIA NIM APIs for protein modeling tools
|
||||
- **Async Processing**: Non-blocking API calls with configurable timeouts and polling
|
||||
- **Debug Mode**: Mock data generation for development and testing
|
||||
- **Result Caching**: Saves intermediate PDB files and FASTA sequences
|
||||
|
||||
**Reward System**:
|
||||
- **Format Rewards**: 0.2 points for correct tool usage in steps 0-2
|
||||
- **Quality Rewards**: pLDDT-based scoring (0.0-1.0) for final complex evaluation
|
||||
- **Progressive Scoring**: Cumulative rewards across workflow stages
|
||||
|
||||
**Data Management**:
|
||||
- **Hugging Face Integration**: Loads protein binding datasets (ronig/protein_binding_sequences)
|
||||
- **File Organization**: Structured output directory with timestamped results
|
||||
- **Comprehensive Logging**: Detailed workflow tracking and performance metrics
|
||||
|
||||
**Research Applications**:
|
||||
- **Drug Discovery**: Design novel protein binders for therapeutic targets
|
||||
- **Protein Engineering**: Optimize protein-protein interactions
|
||||
- **Structural Biology**: Explore protein design space systematically
|
||||
- **AI Training**: Develop protein design capabilities in language models
|
||||
|
||||
**Technical Requirements**:
|
||||
- NVIDIA NIM API access for protein modeling tools
|
||||
- Python environment with protein analysis libraries
|
||||
- Sufficient storage for PDB files and intermediate results
|
||||
|
||||
**Environment Configuration**:
|
||||
- Configurable retry limits and timeout settings
|
||||
- Debug mode for development without API calls
|
||||
- Flexible dataset selection and column mapping
|
||||
- WandB integration for experiment tracking
|
||||
|
||||
**Requirements**: pydantic, datasets, python-dotenv, pyyaml, wandb, atroposlib, nvidia-nim-api-client
|
||||
|
||||
---
|
||||
|
||||
## Support
|
||||
|
||||
For questions or issues with community environments:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,3 @@
|
|||
# We use the NVIDIA NIM API to access hosted models
|
||||
|
||||
NVIDIA_NIM_API_KEY: "YOUR API KEY"
|
||||
NVIDIA_NIM_API_KEY: "YOUR API KEY"
|
||||
|
|
@ -38,7 +38,7 @@ Each episode consists of an LLM navigating a 4-step design pipeline, using state
|
|||
### Step 4: Evaluate Binding (`AlphaFold-Multimer`)
|
||||
- **Input:** Target + binder sequences
|
||||
- **Output:** Complex structure prediction
|
||||
- **Reward:**
|
||||
- **Reward:**
|
||||
- Format OK
|
||||
- No steric clashes
|
||||
- **Bonus:** Contact interface, binding affinity metrics (Not yet implemented)
|
||||
|
|
@ -48,9 +48,9 @@ Each episode consists of an LLM navigating a 4-step design pipeline, using state
|
|||
## 🏆 Reward Function
|
||||
|
||||
The reward is cumulative:
|
||||
- **+0.2**: Successfully generate output in correct format at each step
|
||||
- **+0.2**: Successfully generate output in correct format at each step
|
||||
- **+0.0 to +1.0**: Structural reward based on complex validity, smoothly interpolated from the AlphaFold2-Multimer confidence (pLDDT)
|
||||
- **+1**: High predicted binding affinity (Not yet implemented)
|
||||
- **+1**: High predicted binding affinity (Not yet implemented)
|
||||
|
||||
Sparse, but real. LLMs must *plan* tool use, not just spam actions.
|
||||
|
||||
1
environments/community/protein_design/models/__init__.py
Normal file
1
environments/community/protein_design/models/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Protein design model API modules."""
|
||||
|
|
@ -1,16 +1,15 @@
|
|||
import os
|
||||
import logging
|
||||
import aiohttp
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, List, Any, Optional
|
||||
from pathlib import Path
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2"
|
||||
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
|
||||
|
||||
|
||||
async def call_alphafold2(
|
||||
sequence: str,
|
||||
api_key: str,
|
||||
|
|
@ -24,7 +23,7 @@ async def call_alphafold2(
|
|||
status_url: str = DEFAULT_STATUS_URL,
|
||||
polling_interval: int = 10,
|
||||
timeout: int = 600,
|
||||
max_retries: int = 3
|
||||
max_retries: int = 3,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Call the NVIDIA NIM AlphaFold2 API.
|
||||
|
|
@ -59,16 +58,13 @@ async def call_alphafold2(
|
|||
"iterations": iterations,
|
||||
"databases": databases,
|
||||
"relax_prediction": relax_prediction,
|
||||
"skip_template_search": skip_template_search
|
||||
"skip_template_search": skip_template_search,
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url,
|
||||
json=data,
|
||||
headers=headers,
|
||||
timeout=timeout
|
||||
url, json=data, headers=headers, timeout=timeout
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
|
|
@ -81,7 +77,7 @@ async def call_alphafold2(
|
|||
headers=headers,
|
||||
status_url=status_url,
|
||||
polling_interval=polling_interval,
|
||||
timeout=timeout
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
logger.error("No request ID in response headers")
|
||||
|
|
@ -93,16 +89,18 @@ async def call_alphafold2(
|
|||
return None
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
logger.error(f"Error calling AlphaFold2 API: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
|
||||
|
||||
async def _poll_job_status(
|
||||
req_id: str,
|
||||
headers: Dict[str, str],
|
||||
status_url: str,
|
||||
polling_interval: int = 10,
|
||||
timeout: int = 60
|
||||
timeout: int = 60,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Poll the status endpoint until the job completes.
|
||||
|
|
@ -121,18 +119,20 @@ async def _poll_job_status(
|
|||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{status_url}/{req_id}",
|
||||
headers=headers,
|
||||
timeout=timeout
|
||||
f"{status_url}/{req_id}", headers=headers, timeout=timeout
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
logger.info(f"AlphaFold2 job {req_id} completed")
|
||||
return await response.json()
|
||||
elif response.status == 202:
|
||||
logger.debug(f"AlphaFold2 job {req_id} still running, polling...")
|
||||
logger.debug(
|
||||
f"AlphaFold2 job {req_id} still running, polling..."
|
||||
)
|
||||
await asyncio.sleep(polling_interval)
|
||||
else:
|
||||
logger.error(f"Error checking AlphaFold2 job status: {response.status}")
|
||||
logger.error(
|
||||
f"Error checking AlphaFold2 job status: {response.status}"
|
||||
)
|
||||
text = await response.text()
|
||||
logger.error(f"Response: {text}")
|
||||
return None
|
||||
|
|
@ -1,16 +1,16 @@
|
|||
import os
|
||||
import logging
|
||||
import aiohttp
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from pathlib import Path
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2-multimer"
|
||||
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
|
||||
|
||||
|
||||
def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
|
||||
"""
|
||||
Splits a string containing concatenated PDB file contents.
|
||||
|
|
@ -35,7 +35,9 @@ def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
|
|||
return [pdb for pdb in pdbs if pdb]
|
||||
|
||||
|
||||
def calculate_plddt_from_pdb_string(pdb_string: str) -> Tuple[float, List[float], Dict[str, List[float]]]:
|
||||
def calculate_plddt_from_pdb_string(
|
||||
pdb_string: str,
|
||||
) -> Tuple[float, List[float], Dict[str, List[float]]]:
|
||||
total_plddt = 0.0
|
||||
ca_atom_count = 0
|
||||
plddt_scores_per_ca: List[float] = []
|
||||
|
|
@ -67,10 +69,11 @@ def calculate_plddt_from_pdb_string(pdb_string: str) -> Tuple[float, List[float]
|
|||
average_plddt = total_plddt / ca_atom_count
|
||||
return average_plddt, plddt_scores_per_ca, plddt_scores_per_chain
|
||||
|
||||
|
||||
async def _process_pdb_and_scores_from_api(
|
||||
pdb_contents: List[str],
|
||||
job_id: str,
|
||||
api_response_json: Optional[Dict[str, Any]] = None
|
||||
api_response_json: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Processes a list of PDB strings received from the API.
|
||||
|
|
@ -78,14 +81,19 @@ async def _process_pdb_and_scores_from_api(
|
|||
- Does NOT save files to disk.
|
||||
- Returns a dictionary containing a list of structures, each with its PDB content and scores.
|
||||
"""
|
||||
results: Dict[str, Any] = {
|
||||
"structures": []
|
||||
}
|
||||
results: Dict[str, Any] = {"structures": []}
|
||||
|
||||
if not pdb_contents or not isinstance(pdb_contents, list) or not all(isinstance(s, str) for s in pdb_contents):
|
||||
if (
|
||||
not pdb_contents
|
||||
or not isinstance(pdb_contents, list)
|
||||
or not all(isinstance(s, str) for s in pdb_contents)
|
||||
):
|
||||
logger.warning(f"No valid PDB content strings provided for job {job_id}.")
|
||||
return {"success": False, "error": "No valid PDB content strings from API.", "structures": []}
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"error": "No valid PDB content strings from API.",
|
||||
"structures": [],
|
||||
}
|
||||
|
||||
logger.info(f"Processing {len(pdb_contents)} PDB structure(s) for job {job_id}.")
|
||||
|
||||
|
|
@ -94,12 +102,11 @@ async def _process_pdb_and_scores_from_api(
|
|||
logger.debug(f"Skipping empty PDB string at index {i} for job {job_id}.")
|
||||
continue
|
||||
|
||||
structure_data: Dict[str, Any] = {
|
||||
"model_index": i,
|
||||
"pdb_content": pdb_str
|
||||
}
|
||||
structure_data: Dict[str, Any] = {"model_index": i, "pdb_content": pdb_str}
|
||||
|
||||
avg_plddt, plddts_per_ca_residue, plddts_by_chain = calculate_plddt_from_pdb_string(pdb_str)
|
||||
avg_plddt, plddts_per_ca_residue, plddts_by_chain = (
|
||||
calculate_plddt_from_pdb_string(pdb_str)
|
||||
)
|
||||
|
||||
structure_data["average_plddt"] = avg_plddt
|
||||
structure_data["plddt_scores_per_ca_residue"] = plddts_per_ca_residue
|
||||
|
|
@ -116,13 +123,21 @@ async def _process_pdb_and_scores_from_api(
|
|||
results["structures"].append(structure_data)
|
||||
|
||||
if results["structures"]:
|
||||
logger.info(f"Successfully processed and calculated pLDDTs for {len(results['structures'])} structures for job {job_id}.")
|
||||
logger.info(
|
||||
f"Successfully processed and calculated pLDDTs for "
|
||||
f"{len(results['structures'])} structures for job {job_id}."
|
||||
)
|
||||
else:
|
||||
logger.warning(f"No structures were processed for job {job_id}.")
|
||||
return {"success": True, "message": "No PDB structures found in API response to process.", "structures": []}
|
||||
return {
|
||||
"success": True,
|
||||
"message": "No PDB structures found in API response to process.",
|
||||
"structures": [],
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def call_alphafold2_multimer(
|
||||
sequences: List[str],
|
||||
api_key: str,
|
||||
|
|
@ -135,7 +150,7 @@ async def call_alphafold2_multimer(
|
|||
url: str = DEFAULT_URL,
|
||||
status_url: str = DEFAULT_STATUS_URL,
|
||||
polling_interval: int = 30,
|
||||
timeout: int = 3600
|
||||
timeout: int = 3600,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Call the NVIDIA NIM AlphaFold2-Multimer API.
|
||||
|
|
@ -155,7 +170,7 @@ async def call_alphafold2_multimer(
|
|||
"e_value": e_value,
|
||||
"iterations": iterations,
|
||||
"databases": databases,
|
||||
"relax_prediction": relax_prediction
|
||||
"relax_prediction": relax_prediction,
|
||||
}
|
||||
if selected_models is not None:
|
||||
data["selected_models"] = selected_models
|
||||
|
|
@ -165,10 +180,7 @@ async def call_alphafold2_multimer(
|
|||
initial_post_timeout = min(timeout, 600)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url,
|
||||
json=data,
|
||||
headers=headers,
|
||||
timeout=initial_post_timeout
|
||||
url, json=data, headers=headers, timeout=initial_post_timeout
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
logger.info("AlphaFold2-Multimer job completed synchronously.")
|
||||
|
|
@ -177,130 +189,239 @@ async def call_alphafold2_multimer(
|
|||
if "application/json" in content_type:
|
||||
api_response_json_payload = await response.json()
|
||||
if not isinstance(api_response_json_payload, list):
|
||||
if isinstance(api_response_json_payload, dict) and "error" in api_response_json_payload:
|
||||
logger.error(f"Sync API call returned error: {api_response_json_payload['error']}")
|
||||
return {"success": False, "error": api_response_json_payload['error'], "detail": api_response_json_payload.get("detail","")}
|
||||
return {"success": False, "error": "Sync JSON response not a list of PDBs as expected."}
|
||||
if (
|
||||
isinstance(api_response_json_payload, dict)
|
||||
and "error" in api_response_json_payload
|
||||
):
|
||||
logger.error(
|
||||
f"Sync API call returned error: "
|
||||
f"{api_response_json_payload['error']}"
|
||||
)
|
||||
return {
|
||||
"success": False,
|
||||
"error": api_response_json_payload["error"],
|
||||
"detail": api_response_json_payload.get(
|
||||
"detail", ""
|
||||
),
|
||||
}
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Sync JSON response not a list of PDBs as expected.",
|
||||
}
|
||||
|
||||
req_id_sync = response.headers.get("nvcf-reqid", "sync_job")
|
||||
return await _process_pdb_and_scores_from_api(
|
||||
pdb_contents=api_response_json_payload,
|
||||
job_id=req_id_sync,
|
||||
api_response_json=None
|
||||
api_response_json=None,
|
||||
)
|
||||
else:
|
||||
err_text = await response.text()
|
||||
logger.error(f"Sync response unexpected content type: {content_type}. Response: {err_text[:500]}")
|
||||
return {"success": False, "error": f"Sync response unexpected content type: {content_type}", "detail": err_text}
|
||||
logger.error(
|
||||
f"Sync response unexpected content type: {content_type}. "
|
||||
f"Response: {err_text[:500]}"
|
||||
)
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Sync response unexpected content type: {content_type}",
|
||||
"detail": err_text,
|
||||
}
|
||||
|
||||
elif response.status == 202:
|
||||
req_id = response.headers.get("nvcf-reqid")
|
||||
if req_id:
|
||||
logger.info(f"AlphaFold2-Multimer job submitted, request ID: {req_id}")
|
||||
logger.info(
|
||||
f"AlphaFold2-Multimer job submitted, request ID: {req_id}"
|
||||
)
|
||||
return await _poll_job_status(
|
||||
req_id=req_id,
|
||||
headers=headers,
|
||||
status_url=status_url,
|
||||
polling_interval=polling_interval,
|
||||
overall_timeout=timeout
|
||||
overall_timeout=timeout,
|
||||
)
|
||||
else:
|
||||
logger.error("No request ID in 202 response headers")
|
||||
return {"success": False, "error": "No request ID in 202 response headers"}
|
||||
return {
|
||||
"success": False,
|
||||
"error": "No request ID in 202 response headers",
|
||||
}
|
||||
else:
|
||||
logger.error(f"Error calling AlphaFold2-Multimer API (POST): {response.status}")
|
||||
logger.error(
|
||||
f"Error calling AlphaFold2-Multimer API (POST): {response.status}"
|
||||
)
|
||||
text = await response.text()
|
||||
logger.error(f"Response: {text}")
|
||||
return {"success": False, "error": f"Error calling API: {response.status}", "detail": text}
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Error calling API: {response.status}",
|
||||
"detail": text,
|
||||
}
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Timeout during AlphaFold2-Multimer API (initial POST).")
|
||||
logger.error("Timeout during AlphaFold2-Multimer API (initial POST).")
|
||||
return {"success": False, "error": "Timeout during initial API request"}
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during AlphaFold2-Multimer API call (initial POST): {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"Exception during AlphaFold2-Multimer API call (initial POST): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return {"success": False, "error": f"Exception during API call: {str(e)}"}
|
||||
|
||||
|
||||
async def _poll_job_status(
|
||||
req_id: str,
|
||||
headers: Dict[str, str],
|
||||
status_url: str,
|
||||
polling_interval: int = 30,
|
||||
overall_timeout: int = 3600
|
||||
overall_timeout: int = 3600,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
per_status_request_timeout = 600
|
||||
logger.info(f"Polling job {req_id}. Individual status check timeout: {per_status_request_timeout}s, Polling interval: {polling_interval}s, Overall timeout: {overall_timeout}s")
|
||||
logger.info(
|
||||
f"Polling job {req_id}. Individual status check timeout: "
|
||||
f"{per_status_request_timeout}s, Polling interval: {polling_interval}s, "
|
||||
f"Overall timeout: {overall_timeout}s"
|
||||
)
|
||||
|
||||
while True:
|
||||
current_loop_time = asyncio.get_event_loop().time()
|
||||
elapsed_time = current_loop_time - start_time
|
||||
|
||||
if elapsed_time >= overall_timeout:
|
||||
logger.error(f"Overall polling timeout of {overall_timeout}s exceeded for job {req_id}.")
|
||||
logger.error(
|
||||
f"Overall polling timeout of {overall_timeout}s exceeded for "
|
||||
f"job {req_id}."
|
||||
)
|
||||
return {"success": False, "error": "Overall polling timeout exceeded."}
|
||||
|
||||
remaining_time_for_overall_timeout = overall_timeout - elapsed_time
|
||||
current_status_check_timeout = min(per_status_request_timeout, remaining_time_for_overall_timeout)
|
||||
current_status_check_timeout = min(
|
||||
per_status_request_timeout, remaining_time_for_overall_timeout
|
||||
)
|
||||
|
||||
if current_status_check_timeout <= 0:
|
||||
logger.error(f"Not enough time left for another status check for job {req_id} within overall_timeout.")
|
||||
return {"success": False, "error": "Not enough time for status check within overall timeout."}
|
||||
logger.error(
|
||||
f"Not enough time left for another status check for job {req_id} "
|
||||
f"within overall_timeout."
|
||||
)
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Not enough time for status check within overall timeout.",
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
logger.debug(f"Checking status for {req_id} with timeout {current_status_check_timeout}s.")
|
||||
logger.debug(
|
||||
f"Checking status for {req_id} with timeout "
|
||||
f"{current_status_check_timeout}s."
|
||||
)
|
||||
async with session.get(
|
||||
f"{status_url}/{req_id}",
|
||||
headers=headers,
|
||||
timeout=current_status_check_timeout
|
||||
timeout=current_status_check_timeout,
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
logger.info(f"AlphaFold2-Multimer job {req_id} completed (status 200).")
|
||||
if response.content_type == 'application/json':
|
||||
logger.info(
|
||||
f"AlphaFold2-Multimer job {req_id} completed (status 200)."
|
||||
)
|
||||
if response.content_type == "application/json":
|
||||
try:
|
||||
api_response_json_payload = await response.json()
|
||||
if not isinstance(api_response_json_payload, list):
|
||||
if isinstance(api_response_json_payload, dict) and "error" in api_response_json_payload:
|
||||
logger.error(f"Job {req_id}: API returned error: {api_response_json_payload['error']}")
|
||||
return {"success": False, "error": api_response_json_payload['error'], "detail": api_response_json_payload.get("detail","")}
|
||||
logger.error(f"Job {req_id}: Expected API response to be a list of PDB strings, got {type(api_response_json_payload)}.")
|
||||
return {"success": False, "error": "API response was not a list of PDB strings."}
|
||||
if (
|
||||
isinstance(api_response_json_payload, dict)
|
||||
and "error" in api_response_json_payload
|
||||
):
|
||||
logger.error(
|
||||
f"Job {req_id}: API returned error: "
|
||||
f"{api_response_json_payload['error']}"
|
||||
)
|
||||
return {
|
||||
"success": False,
|
||||
"error": api_response_json_payload["error"],
|
||||
"detail": api_response_json_payload.get(
|
||||
"detail", ""
|
||||
),
|
||||
}
|
||||
logger.error(
|
||||
f"Job {req_id}: Expected API response to be a list of PDB strings, "
|
||||
f"got {type(api_response_json_payload)}."
|
||||
)
|
||||
return {
|
||||
"success": False,
|
||||
"error": "API response was not a list of PDB strings.",
|
||||
}
|
||||
|
||||
return await _process_pdb_and_scores_from_api(
|
||||
pdb_contents=api_response_json_payload,
|
||||
job_id=req_id,
|
||||
api_response_json=None
|
||||
api_response_json=None,
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Job {req_id}: Failed to decode JSON response from API.", exc_info=True)
|
||||
logger.error(
|
||||
f"Job {req_id}: Failed to decode JSON response from API.",
|
||||
exc_info=True,
|
||||
)
|
||||
raw_text = await response.text()
|
||||
return {"success": False, "error": "Failed to decode JSON response.", "detail": raw_text[:500]}
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Failed to decode JSON response.",
|
||||
"detail": raw_text[:500],
|
||||
}
|
||||
else:
|
||||
raw_text = await response.text()
|
||||
logger.error(f"Job {req_id}: Unexpected content type {response.content_type}. Expected application/json. Response: {raw_text[:500]}")
|
||||
return {"success": False, "error": f"Unexpected content type: {response.content_type}", "detail": raw_text}
|
||||
logger.error(
|
||||
f"Job {req_id}: Unexpected content type {response.content_type}. "
|
||||
f"Expected application/json. Response: {raw_text[:500]}"
|
||||
)
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Unexpected content type: {response.content_type}",
|
||||
"detail": raw_text,
|
||||
}
|
||||
elif response.status == 202:
|
||||
try:
|
||||
job_status_json = await response.json()
|
||||
percent_complete = job_status_json.get('percentComplete', 'N/A')
|
||||
status_message = job_status_json.get('status', 'running')
|
||||
percent_complete = job_status_json.get(
|
||||
"percentComplete", "N/A"
|
||||
)
|
||||
status_message = job_status_json.get("status", "running")
|
||||
logger.debug(
|
||||
f"Job {req_id} status: {status_message} ({percent_complete}% complete). Polling again in {polling_interval}s."
|
||||
f"Job {req_id} status: {status_message} ({percent_complete}% complete). "
|
||||
f"Polling again in {polling_interval}s."
|
||||
)
|
||||
except (aiohttp.ContentTypeError, json.JSONDecodeError):
|
||||
logger.debug(
|
||||
f"Job {req_id} still running (202 status, non-JSON/malformed JSON body). Polling again in {polling_interval}s."
|
||||
f"Job {req_id} still running (202 status, non-JSON/malformed JSON body). "
|
||||
f"Polling again in {polling_interval}s."
|
||||
)
|
||||
await asyncio.sleep(polling_interval)
|
||||
else:
|
||||
text = await response.text()
|
||||
logger.error(f"Error checking AlphaFold2-Multimer job status {req_id}: HTTP {response.status} - {text}")
|
||||
return {"success": False, "error": f"Status check failed with HTTP {response.status}", "detail": text}
|
||||
logger.error(
|
||||
f"Error checking AlphaFold2-Multimer job status {req_id}: "
|
||||
f"HTTP {response.status} - {text}"
|
||||
)
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Status check failed with HTTP {response.status}",
|
||||
"detail": text,
|
||||
}
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Client-side timeout ({current_status_check_timeout}s) during status check for job {req_id}. Retrying poll after {polling_interval}s sleep.")
|
||||
logger.warning(
|
||||
f"Client-side timeout ({current_status_check_timeout}s) during status check for "
|
||||
f"job {req_id}. Retrying poll after {polling_interval}s sleep."
|
||||
)
|
||||
await asyncio.sleep(polling_interval)
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"Client error polling job status for {req_id}: {e}. Retrying poll after {polling_interval}s.", exc_info=True)
|
||||
logger.error(
|
||||
f"Client error polling job status for {req_id}: {e}. "
|
||||
f"Retrying poll after {polling_interval}s.",
|
||||
exc_info=True,
|
||||
)
|
||||
await asyncio.sleep(polling_interval)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error polling job status {req_id}: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"Unexpected error polling job status {req_id}: {e}", exc_info=True
|
||||
)
|
||||
return {"success": False, "error": f"Unexpected polling error: {str(e)}"}
|
||||
|
|
@ -1,16 +1,15 @@
|
|||
import os
|
||||
import logging
|
||||
import aiohttp
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, List, Any, Optional, Union
|
||||
from pathlib import Path
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/ipd/proteinmpnn/predict"
|
||||
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
|
||||
|
||||
|
||||
async def call_proteinmpnn(
|
||||
input_pdb: str,
|
||||
api_key: str,
|
||||
|
|
@ -20,7 +19,7 @@ async def call_proteinmpnn(
|
|||
url: str = DEFAULT_URL,
|
||||
status_url: str = DEFAULT_STATUS_URL,
|
||||
polling_interval: int = 10,
|
||||
timeout: int = 60
|
||||
timeout: int = 60,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Call the NVIDIA NIM ProteinMPNN API.
|
||||
|
|
@ -49,16 +48,13 @@ async def call_proteinmpnn(
|
|||
"input_pdb": input_pdb,
|
||||
"ca_only": ca_only,
|
||||
"use_soluble_model": use_soluble_model,
|
||||
"sampling_temp": sampling_temp
|
||||
"sampling_temp": sampling_temp,
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url,
|
||||
json=data,
|
||||
headers=headers,
|
||||
timeout=timeout
|
||||
url, json=data, headers=headers, timeout=timeout
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
|
|
@ -71,7 +67,7 @@ async def call_proteinmpnn(
|
|||
headers=headers,
|
||||
status_url=status_url,
|
||||
polling_interval=polling_interval,
|
||||
timeout=timeout
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
logger.error("No request ID in response headers")
|
||||
|
|
@ -85,12 +81,13 @@ async def call_proteinmpnn(
|
|||
logger.error(f"Error calling ProteinMPNN API: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def _poll_job_status(
|
||||
req_id: str,
|
||||
headers: Dict[str, str],
|
||||
status_url: str,
|
||||
polling_interval: int = 10,
|
||||
timeout: int = 60
|
||||
timeout: int = 60,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Poll the status endpoint until the job completes.
|
||||
|
|
@ -109,18 +106,20 @@ async def _poll_job_status(
|
|||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{status_url}/{req_id}",
|
||||
headers=headers,
|
||||
timeout=timeout
|
||||
f"{status_url}/{req_id}", headers=headers, timeout=timeout
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
logger.info(f"ProteinMPNN job {req_id} completed")
|
||||
return await response.json()
|
||||
elif response.status == 202:
|
||||
logger.debug(f"ProteinMPNN job {req_id} still running, polling...")
|
||||
logger.debug(
|
||||
f"ProteinMPNN job {req_id} still running, polling..."
|
||||
)
|
||||
await asyncio.sleep(polling_interval)
|
||||
else:
|
||||
logger.error(f"Error checking ProteinMPNN job status: {response.status}")
|
||||
logger.error(
|
||||
f"Error checking ProteinMPNN job status: {response.status}"
|
||||
)
|
||||
text = await response.text()
|
||||
logger.error(f"Response: {text}")
|
||||
return None
|
||||
|
|
@ -1,16 +1,15 @@
|
|||
import os
|
||||
import logging
|
||||
import aiohttp
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, List, Any, Optional, Union
|
||||
from pathlib import Path
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/ipd/rfdiffusion/generate"
|
||||
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
|
||||
|
||||
|
||||
async def call_rfdiffusion(
|
||||
input_pdb: str,
|
||||
api_key: str,
|
||||
|
|
@ -20,7 +19,7 @@ async def call_rfdiffusion(
|
|||
url: str = DEFAULT_URL,
|
||||
status_url: str = DEFAULT_STATUS_URL,
|
||||
polling_interval: int = 10,
|
||||
timeout: int = 60
|
||||
timeout: int = 60,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Call the NVIDIA NIM RFDiffusion API.
|
||||
|
|
@ -45,10 +44,7 @@ async def call_rfdiffusion(
|
|||
"NVCF-POLL-SECONDS": "300",
|
||||
}
|
||||
|
||||
data = {
|
||||
"input_pdb": input_pdb,
|
||||
"diffusion_steps": diffusion_steps
|
||||
}
|
||||
data = {"input_pdb": input_pdb, "diffusion_steps": diffusion_steps}
|
||||
|
||||
if contigs:
|
||||
data["contigs"] = contigs
|
||||
|
|
@ -58,10 +54,7 @@ async def call_rfdiffusion(
|
|||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url,
|
||||
json=data,
|
||||
headers=headers,
|
||||
timeout=timeout
|
||||
url, json=data, headers=headers, timeout=timeout
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
|
|
@ -74,7 +67,7 @@ async def call_rfdiffusion(
|
|||
headers=headers,
|
||||
status_url=status_url,
|
||||
polling_interval=polling_interval,
|
||||
timeout=timeout
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
logger.error("No request ID in response headers")
|
||||
|
|
@ -88,12 +81,13 @@ async def call_rfdiffusion(
|
|||
logger.error(f"Error calling RFDiffusion API: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def _poll_job_status(
|
||||
req_id: str,
|
||||
headers: Dict[str, str],
|
||||
status_url: str,
|
||||
polling_interval: int = 10,
|
||||
timeout: int = 60
|
||||
timeout: int = 60,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Poll the status endpoint until the job completes.
|
||||
|
|
@ -112,18 +106,20 @@ async def _poll_job_status(
|
|||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{status_url}/{req_id}",
|
||||
headers=headers,
|
||||
timeout=timeout
|
||||
f"{status_url}/{req_id}", headers=headers, timeout=timeout
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
logger.info(f"RFDiffusion job {req_id} completed")
|
||||
return await response.json()
|
||||
elif response.status == 202:
|
||||
logger.debug(f"RFDiffusion job {req_id} still running, polling...")
|
||||
logger.debug(
|
||||
f"RFDiffusion job {req_id} still running, polling..."
|
||||
)
|
||||
await asyncio.sleep(polling_interval)
|
||||
else:
|
||||
logger.error(f"Error checking RFDiffusion job status: {response.status}")
|
||||
logger.error(
|
||||
f"Error checking RFDiffusion job status: {response.status}"
|
||||
)
|
||||
text = await response.text()
|
||||
logger.error(f"Response: {text}")
|
||||
return None
|
||||
|
|
@ -1,25 +1,25 @@
|
|||
import logging
|
||||
from typing import Dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# System prompt sent to the model for the staged protein-design workflow.
# It is built from adjacent string literals (implicit concatenation) so that
# every source line stays within the line-length limit while the runtime
# value remains a single string. Paragraph breaks are explicit "\n\n".
SYSTEM_PROMPT = (
    "You are a specialized AI system for de novo protein design via a staged simulation loop. "
    "Your objective is to generate binder sequences that are structurally and functionally "
    "optimized to bind a given target protein.\n\n"
    "You will be guided through a multi-step pipeline:\n\n"
    "1. Structure prediction (AlphaFold)\n"
    "2. Binder backbone generation (RFdiffusion)\n"
    "3. Sequence design (ProteinMPNN)\n"
    "4. Structure evaluation (AlphaFold-Multimer)\n"
    "5. Feedback loop\n\n"
    "You must always:\n"
    "- Respect the required file format for each tool (e.g., FASTA, PDB).\n"
    "- Structure your outputs cleanly so they can be parsed and executed programmatically.\n"
    "- Be explicit in all configuration steps (e.g., contigs, hotspots).\n"
    "- Minimize ambiguity or verbosity; prefer concise and functional outputs.\n"
    "- Reason step-by-step when appropriate."
)
|
||||
|
||||
def construct_user_prompt(state: dict) -> str:
|
||||
"""
|
||||
|
|
@ -42,19 +42,20 @@ def construct_user_prompt(state: dict) -> str:
|
|||
"You must provide the 'sequence' argument."
|
||||
)
|
||||
elif internal_step == 1:
|
||||
target_pdb_preview = state.get("target_pdb_preview", "PDB preview not available.")
|
||||
|
||||
chain_details = state.get("target_chain_details", {})
|
||||
if chain_details:
|
||||
chain_info_parts = []
|
||||
for chain_id, details in chain_details.items():
|
||||
min_r = details.get('min_residue', 'N/A')
|
||||
max_r = details.get('max_residue', 'N/A')
|
||||
l = details.get('length', 'N/A')
|
||||
chain_info_parts.append(f"Chain {chain_id} (Residues: {min_r}-{max_r}, Length: {l} amino acids)")
|
||||
min_r = details.get("min_residue", "N/A")
|
||||
max_r = details.get("max_residue", "N/A")
|
||||
length = details.get("length", "N/A")
|
||||
chain_info_parts.append(
|
||||
f"Chain {chain_id} (Residues: {min_r}-{max_r}, Length: {length} amino acids)"
|
||||
)
|
||||
chain_info_str = "\n- ".join(chain_info_parts)
|
||||
if chain_info_str:
|
||||
chain_info_str = "- " + chain_info_str
|
||||
chain_info_str = "- " + chain_info_str
|
||||
else:
|
||||
chain_info_str = "Chain information not available or PDB not yet processed."
|
||||
|
||||
|
|
@ -62,19 +63,26 @@ def construct_user_prompt(state: dict) -> str:
|
|||
f"The 3D structure of the target protein has been predicted.\n"
|
||||
f"Target Protein Chain Details:\n{chain_info_str}\n\n"
|
||||
"Your task is to design a binder backbone using the 'design_binder_backbone_rfdiffusion' tool. "
|
||||
"You MUST specify 'contigs' for this tool. The 'contigs' string defines segments from the target PDB and segments for the new binder. "
|
||||
"You MUST specify 'contigs' for this tool. The 'contigs' string defines segments from the target PDB "
|
||||
"and segments for the new binder. "
|
||||
"Examples:\n"
|
||||
" - To use residues 10 through 100 of target chain A, and then diffuse a 60-residue binder: 'A10-100/0 60'\n"
|
||||
" - To use chain B from residue 5 to 50, then diffuse a 30-residue binder, then use chain B from residue 60 to 100: 'B5-50/0 30 B60-100'\n"
|
||||
"You MUST use the chain IDs and residue ranges exactly as provided in the 'Target Protein Chain Details' above. "
|
||||
" - To use residues 10 through 100 of target chain A, and then diffuse a 60-residue binder: "
|
||||
"'A10-100/0 60'\n"
|
||||
" - To use chain B from residue 5 to 50, then diffuse a 30-residue binder, then use chain B "
|
||||
"from residue 60 to 100: 'B5-50/0 30 B60-100'\n"
|
||||
"You MUST use the chain IDs and residue ranges exactly as provided in the "
|
||||
"'Target Protein Chain Details' above. "
|
||||
"Do not invent chains or residue numbers outside these specified ranges for the target segments. "
|
||||
"For binder segments (e.g., '/0 60'), specify the desired length (e.g., 60).\n"
|
||||
"Optionally, provide 'hotspot_residues' (e.g., ['A50', 'A52']), ensuring these residues exist on the target as per the details above."
|
||||
"Optionally, provide 'hotspot_residues' (e.g., ['A50', 'A52']), ensuring these residues exist "
|
||||
"on the target as per the details above."
|
||||
)
|
||||
elif internal_step == 2:
|
||||
binder_pdb_content = state.get("binder_backbone_pdb_content")
|
||||
|
||||
binder_pdb_preview = state.get("binder_pdb_preview", "Binder PDB preview not available.")
|
||||
binder_pdb_preview = state.get(
|
||||
"binder_pdb_preview", "Binder PDB preview not available."
|
||||
)
|
||||
binder_chain_info_str = "Binder chain information not available."
|
||||
|
||||
if binder_pdb_content:
|
||||
|
|
@ -83,10 +91,12 @@ def construct_user_prompt(state: dict) -> str:
|
|||
if binder_chain_details:
|
||||
chain_info_parts = []
|
||||
for cID, d_details in binder_chain_details.items():
|
||||
min_r = d_details.get('min_residue', 'N/A')
|
||||
max_r = d_details.get('max_residue', 'N/A')
|
||||
l = d_details.get('length', 'N/A')
|
||||
chain_info_parts.append(f"Chain {cID} (Residues: {min_r}-{max_r}, Length: {l} amino acids)")
|
||||
min_r = d_details.get("min_residue", "N/A")
|
||||
max_r = d_details.get("max_residue", "N/A")
|
||||
length = d_details.get("length", "N/A")
|
||||
chain_info_parts.append(
|
||||
f"Chain {cID} (Residues: {min_r}-{max_r}, Length: {length} amino acids)"
|
||||
)
|
||||
binder_chain_info_str = "\n- ".join(chain_info_parts)
|
||||
if binder_chain_info_str:
|
||||
binder_chain_info_str = "- " + binder_chain_info_str
|
||||
|
|
@ -99,7 +109,8 @@ def construct_user_prompt(state: dict) -> str:
|
|||
user_prompt_str = (
|
||||
f"A binder backbone has been generated. Binder PDB preview:\n{binder_pdb_preview}\n"
|
||||
f"Binder chain information:\n{binder_chain_info_str}.\n"
|
||||
"Now, design an optimal amino acid sequence for this binder backbone using the 'design_binder_sequence_proteinmpnn' tool. "
|
||||
"Now, design an optimal amino acid sequence for this binder backbone using the "
|
||||
"'design_binder_sequence_proteinmpnn' tool. "
|
||||
"You can optionally specify 'sampling_temp' (e.g., [0.1, 0.2])."
|
||||
)
|
||||
elif internal_step == 3:
|
||||
|
|
@ -110,27 +121,39 @@ def construct_user_prompt(state: dict) -> str:
|
|||
if len(designed_binder_seq_data) == 1:
|
||||
binder_display_str = designed_binder_seq_data[0]
|
||||
else:
|
||||
binder_display_str = f"{len(designed_binder_seq_data)} chains: " + \
|
||||
", ".join([f"Chain {i+1} ({len(s)} aa): {s[:20]}..."
|
||||
for i, s in enumerate(designed_binder_seq_data)])
|
||||
binder_display_str = (
|
||||
f"{len(designed_binder_seq_data)} chains: "
|
||||
+ ", ".join(
|
||||
[
|
||||
f"Chain {i+1} ({len(s)} aa): {s[:20]}..."
|
||||
for i, s in enumerate(designed_binder_seq_data)
|
||||
]
|
||||
)
|
||||
)
|
||||
elif isinstance(designed_binder_seq_data, str):
|
||||
binder_display_str = designed_binder_seq_data
|
||||
binder_display_str = designed_binder_seq_data
|
||||
|
||||
user_prompt_str = (
|
||||
f"A binder has been designed. Designed binder sequence(s): {binder_display_str}. "
|
||||
f"The original target sequence was: {target_sequence[:60]}...\n"
|
||||
"Finally, evaluate the binding complex of the original target protein and ALL chains of this designed binder using the "
|
||||
"'evaluate_binder_complex_alphafold2_multimer' tool. "
|
||||
"Finally, evaluate the binding complex of the original target protein and ALL chains of this "
|
||||
"designed binder using the 'evaluate_binder_complex_alphafold2_multimer' tool. "
|
||||
"You can optionally specify 'relax_prediction' (default is True)."
|
||||
)
|
||||
else:
|
||||
user_prompt_str = "The protein design workflow is complete. No further actions required by you for this item. If successful, the key metric was the pLDDT of the complex."
|
||||
user_prompt_str = (
|
||||
"The protein design workflow is complete. No further actions required by you for this item. "
|
||||
"If successful, the key metric was the pLDDT of the complex."
|
||||
)
|
||||
|
||||
if state.get("retry_count_this_internal_step", 0) > 0 and internal_step < 4:
|
||||
retry_prefix = "Your previous attempt at this step was not successful. "
|
||||
if state.get("previous_tool_error_message"):
|
||||
retry_prefix += f"Details: {state['previous_tool_error_message']}. "
|
||||
retry_prefix += "Please review the requirements and PDB details carefully and try again to correctly use the expected tool.\n\n"
|
||||
retry_prefix += (
|
||||
"Please review the requirements and PDB details carefully and try again to correctly use "
|
||||
"the expected tool.\n\n"
|
||||
)
|
||||
user_prompt_str = retry_prefix + user_prompt_str
|
||||
|
||||
return user_prompt_str
|
||||
File diff suppressed because it is too large
Load diff
92
environments/community/protein_design/tool_definitions.py
Normal file
92
environments/community/protein_design/tool_definitions.py
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
# Tool schemas offered to the model, one per stage of the protein-design
# workflow. Each entry is a dict of the form
#   {"type": "function", "function": {"name", "description", "parameters"}}
# where "parameters" is a JSON-Schema-style object description.
# NOTE(review): this looks like the OpenAI-style function-calling format;
# confirm against whichever client consumes ALL_TOOLS_LIST.

# Stage 1: predict the target structure. The only argument is the sequence.
PREDICT_TARGET_STRUCTURE_TOOL = {
    "type": "function",
    "function": {
        "name": "predict_target_structure_alphafold2",
        "description": "Predicts the 3D structure of the target protein sequence using AlphaFold2.",
        "parameters": {
            "type": "object",
            "properties": {
                "sequence": {
                    "type": "string",
                    "description": "Amino acid sequence of the target protein.",
                },
            },
            "required": ["sequence"],
        },
    },
}

# Stage 2: generate a binder backbone. "contigs" is mandatory,
# "hotspot_residues" is optional.
DESIGN_BINDER_BACKBONE_TOOL = {
    "type": "function",
    "function": {
        "name": "design_binder_backbone_rfdiffusion",
        "description": (
            "Generates a novel protein binder backbone using RFDiffusion, "
            "conditioned on the target protein structure."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "contigs": {
                    "type": "string",
                    "description": "RFDiffusion contigs (e.g., 'A1-100/0 50-70').",
                },
                "hotspot_residues": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Optional hotspot residues (e.g., ['A50', 'A52']).",
                },
            },
            "required": ["contigs"],
        },
    },
}

# Stage 3: design a sequence for the backbone. No argument is required.
DESIGN_BINDER_SEQUENCE_TOOL = {
    "type": "function",
    "function": {
        "name": "design_binder_sequence_proteinmpnn",
        "description": "Designs an amino acid sequence for the generated binder backbone.",
        "parameters": {
            "type": "object",
            "properties": {
                "sampling_temp": {
                    "type": "array",
                    "items": {"type": "number"},
                    "description": (
                        "Sampling temperatures (e.g., [0.1, 0.2]). Default [0.1]."
                    ),
                }
            },
            "required": [],
        },
    },
}

# Stage 4: evaluate the target/binder complex. No argument is required.
EVALUATE_COMPLEX_TOOL = {
    "type": "function",
    "function": {
        "name": "evaluate_binder_complex_alphafold2_multimer",
        "description": (
            "Predicts the complex structure of target and designed binder, "
            "providing quality scores."
        ),
        "parameters": {
            "type": "object",
            "properties": {
                "relax_prediction": {
                    "type": "boolean",
                    "description": "Whether to relax the prediction. Default True.",
                }
            },
            "required": [],
        },
    },
}

# All tool schemas, listed in the order the workflow stages run.
ALL_TOOLS_LIST = [
    PREDICT_TARGET_STRUCTURE_TOOL,
    DESIGN_BINDER_BACKBONE_TOOL,
    DESIGN_BINDER_SEQUENCE_TOOL,
    EVALUATE_COMPLEX_TOOL,
]
|
||||
|
|
@ -1,19 +1,26 @@
|
|||
import logging
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, Any, List, Tuple, Optional, Union
|
||||
from pathlib import Path
|
||||
from environments.hack0.protein_design_env.models.alphafold2 import call_alphafold2
|
||||
from environments.hack0.protein_design_env.models.rfdiffusion import call_rfdiffusion
|
||||
from environments.hack0.protein_design_env.models.proteinmpnn import call_proteinmpnn
|
||||
from environments.hack0.protein_design_env.models.alphafold2_multimer import call_alphafold2_multimer
|
||||
from environments.hack0.protein_design_env.utils.pdb_utils import get_pdb_chain_details
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from .models.alphafold2 import call_alphafold2
|
||||
from .models.alphafold2_multimer import call_alphafold2_multimer
|
||||
from .models.proteinmpnn import call_proteinmpnn
|
||||
from .models.rfdiffusion import call_rfdiffusion
|
||||
from .utils.pdb_utils import get_pdb_chain_details
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
def __init__(self, nim_api_key: str, api_timeout: int, polling_interval: int,
|
||||
output_dir: Path, debug_protein_design_calls: bool):
|
||||
def __init__(
|
||||
self,
|
||||
nim_api_key: str,
|
||||
api_timeout: int,
|
||||
polling_interval: int,
|
||||
output_dir: Path,
|
||||
debug_protein_design_calls: bool,
|
||||
):
|
||||
self.nim_api_key = nim_api_key
|
||||
self.api_timeout = api_timeout
|
||||
self.polling_interval = polling_interval
|
||||
|
|
@ -21,18 +28,23 @@ class ToolExecutor:
|
|||
self.debug_protein_design_calls = debug_protein_design_calls
|
||||
self._debug_af2m_call_count = 0
|
||||
|
||||
def _validate_rfd_contigs(self, contigs_str: str, target_chain_details: Dict[str, Dict[str, int]]) -> Optional[str]:
|
||||
def _validate_rfd_contigs(
|
||||
self, contigs_str: str, target_chain_details: Dict[str, Dict[str, int]]
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Validates the RFDiffusion contigs string against target PDB chain details.
|
||||
Returns None if valid, or an error message string if invalid.
|
||||
"""
|
||||
if not contigs_str: return "Contigs string is empty."
|
||||
if not contigs_str:
|
||||
return "Contigs string is empty."
|
||||
|
||||
target_segment_pattern = re.compile(r"([A-Za-z0-9])(\d+)-(\d+)")
|
||||
active_contig_parts = contigs_str.split('/') # Split by binder definition markers
|
||||
active_contig_parts = contigs_str.split(
|
||||
"/"
|
||||
) # Split by binder definition markers
|
||||
|
||||
for part in active_contig_parts:
|
||||
chain_segments_in_part = part.strip().split(' ')
|
||||
chain_segments_in_part = part.strip().split(" ")
|
||||
for segment_text in chain_segments_in_part:
|
||||
segment_text = segment_text.strip()
|
||||
if not segment_text or segment_text.isdigit():
|
||||
|
|
@ -45,42 +57,61 @@ class ToolExecutor:
|
|||
seg_end = int(seg_end_str)
|
||||
|
||||
if seg_chain_id not in target_chain_details:
|
||||
return f"Chain '{seg_chain_id}' in contig segment '{segment_text}' not in target. Valid: {list(target_chain_details.keys())}."
|
||||
return (
|
||||
f"Chain '{seg_chain_id}' in contig segment '{segment_text}' not in target. "
|
||||
f"Valid: {list(target_chain_details.keys())}."
|
||||
)
|
||||
|
||||
chain_min = target_chain_details[seg_chain_id]["min_residue"]
|
||||
chain_max = target_chain_details[seg_chain_id]["max_residue"]
|
||||
|
||||
if not (chain_min <= seg_start <= chain_max and chain_min <= seg_end <= chain_max and seg_start <= seg_end):
|
||||
return (f"Residue range {seg_start}-{seg_end} for chain '{seg_chain_id}' in '{segment_text}' "
|
||||
f"is invalid/out of bounds. Chain '{seg_chain_id}' actual range: {chain_min}-{chain_max}.")
|
||||
if not (
|
||||
chain_min <= seg_start <= chain_max
|
||||
and chain_min <= seg_end <= chain_max
|
||||
and seg_start <= seg_end
|
||||
):
|
||||
return (
|
||||
f"Residue range {seg_start}-{seg_end} for chain '{seg_chain_id}' in '{segment_text}' "
|
||||
f"is invalid/out of bounds. Chain '{seg_chain_id}' actual range: {chain_min}-{chain_max}."
|
||||
)
|
||||
return None
|
||||
|
||||
def _validate_rfd_hotspots(self, hotspot_list: List[str], target_chain_details: Dict[str, Dict[str, int]]) -> Optional[str]:
|
||||
def _validate_rfd_hotspots(
|
||||
self, hotspot_list: List[str], target_chain_details: Dict[str, Dict[str, int]]
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Validates hotspot residues (e.g., ["A50", "B25"]) against target PDB chain details.
|
||||
Returns None if valid, or an error message string if invalid.
|
||||
"""
|
||||
if not hotspot_list: return None
|
||||
if not hotspot_list:
|
||||
return None
|
||||
|
||||
hotspot_pattern = re.compile(r"([A-Za-z0-9])(\d+)")
|
||||
|
||||
for hotspot_str in hotspot_list:
|
||||
match = hotspot_pattern.fullmatch(hotspot_str.strip()) # Add strip
|
||||
match = hotspot_pattern.fullmatch(hotspot_str.strip()) # Add strip
|
||||
if not match:
|
||||
return f"Hotspot '{hotspot_str}' is not in expected format (e.g., 'A50')."
|
||||
return (
|
||||
f"Hotspot '{hotspot_str}' is not in expected format (e.g., 'A50')."
|
||||
)
|
||||
|
||||
hs_chain_id, hs_res_num_str = match.groups()
|
||||
hs_res_num = int(hs_res_num_str)
|
||||
|
||||
if hs_chain_id not in target_chain_details:
|
||||
return f"Chain '{hs_chain_id}' for hotspot '{hotspot_str}' not in target. Valid: {list(target_chain_details.keys())}."
|
||||
return (
|
||||
f"Chain '{hs_chain_id}' for hotspot '{hotspot_str}' not in target. "
|
||||
f"Valid: {list(target_chain_details.keys())}."
|
||||
)
|
||||
|
||||
chain_min = target_chain_details[hs_chain_id]["min_residue"]
|
||||
chain_max = target_chain_details[hs_chain_id]["max_residue"]
|
||||
|
||||
if not (chain_min <= hs_res_num <= chain_max):
|
||||
return (f"Residue {hs_res_num} for hotspot '{hotspot_str}' (chain '{hs_chain_id}') "
|
||||
f"out of bounds. Chain '{hs_chain_id}' actual range: {chain_min}-{chain_max}.")
|
||||
return (
|
||||
f"Residue {hs_res_num} for hotspot '{hotspot_str}' (chain '{hs_chain_id}') "
|
||||
f"out of bounds. Chain '{hs_chain_id}' actual range: {chain_min}-{chain_max}."
|
||||
)
|
||||
return None
|
||||
|
||||
async def _run_nim_alphafold2(self, args: Dict, workflow_state: Dict) -> Dict:
|
||||
|
|
@ -96,13 +127,18 @@ class ToolExecutor:
|
|||
state_updates = {}
|
||||
|
||||
if self.debug_protein_design_calls:
|
||||
logger.warning(f"DEBUG MODE: Bypassing AlphaFold2 API call for workflow {item_id}.")
|
||||
logger.warning(
|
||||
f"DEBUG MODE: Bypassing AlphaFold2 API call for workflow {item_id}."
|
||||
)
|
||||
module_dir = Path(__file__).parent
|
||||
fixed_pdb_path = module_dir / "debug_target.pdb"
|
||||
|
||||
if not fixed_pdb_path.exists():
|
||||
logger.error(f"Debug mode failed: {fixed_pdb_path} not found.")
|
||||
tool_output = {"success": False, "error": f"Debug mode failed: Required file {fixed_pdb_path} not found."}
|
||||
tool_output = {
|
||||
"success": False,
|
||||
"error": f"Debug mode failed: Required file {fixed_pdb_path} not found.",
|
||||
}
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
with open(fixed_pdb_path, "r") as f:
|
||||
|
|
@ -115,7 +151,10 @@ class ToolExecutor:
|
|||
state_updates["target_pdb_preview"] = pdb_preview
|
||||
state_updates["target_structure_predicted"] = True
|
||||
|
||||
debug_pdb_path = self.output_dir / f"target_{item_id}_s{current_internal_step}_af2_DEBUG.pdb"
|
||||
debug_pdb_path = (
|
||||
self.output_dir
|
||||
/ f"target_{item_id}_s{current_internal_step}_af2_DEBUG.pdb"
|
||||
)
|
||||
with open(debug_pdb_path, "w") as f:
|
||||
f.write(pdb_content)
|
||||
logger.info(f"DEBUG MODE: Copied fixed AlphaFold2 PDB to {debug_pdb_path}")
|
||||
|
|
@ -124,25 +163,31 @@ class ToolExecutor:
|
|||
"success": True,
|
||||
"message": "DEBUG MODE: Used fixed PDB for AlphaFold2.",
|
||||
"target_pdb_preview": pdb_preview,
|
||||
"saved_pdb_path": str(debug_pdb_path)
|
||||
"saved_pdb_path": str(debug_pdb_path),
|
||||
}
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
sequence_from_llm = args.get("sequence")
|
||||
if not sequence_from_llm:
|
||||
tool_output = {"success": False, "error": "Missing 'sequence' for AlphaFold2."}
|
||||
tool_output = {
|
||||
"success": False,
|
||||
"error": "Missing 'sequence' for AlphaFold2.",
|
||||
}
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
actual_sequence_to_use = target_sequence_from_state
|
||||
if sequence_from_llm != target_sequence_from_state:
|
||||
logger.warning(
|
||||
f"LLM provided sequence '{sequence_from_llm[:20]}...' for 'predict_target_structure_alphafold2'. "
|
||||
f"However, this tool will use the canonical target sequence from the workflow state: '{target_sequence_from_state[:20]}...'"
|
||||
f"However, this tool will use the canonical target sequence from the workflow state: "
|
||||
f"'{target_sequence_from_state[:20]}...'"
|
||||
)
|
||||
|
||||
api_result = await call_alphafold2(
|
||||
sequence=actual_sequence_to_use, api_key=self.nim_api_key,
|
||||
timeout=self.api_timeout, polling_interval=self.polling_interval
|
||||
sequence=actual_sequence_to_use,
|
||||
api_key=self.nim_api_key,
|
||||
timeout=self.api_timeout,
|
||||
polling_interval=self.polling_interval,
|
||||
)
|
||||
if api_result and isinstance(api_result, list) and api_result[0]:
|
||||
pdb_content = api_result[0]
|
||||
|
|
@ -153,20 +198,34 @@ class ToolExecutor:
|
|||
state_updates["target_pdb_preview"] = pdb_preview
|
||||
state_updates["target_structure_predicted"] = True
|
||||
|
||||
pdb_path = self.output_dir / f"target_{item_id}_s{current_internal_step}_af2.pdb"
|
||||
with open(pdb_path, "w") as f: f.write(pdb_content)
|
||||
logger.info(f"Workflow {item_id}: AlphaFold2 PDB saved to {pdb_path}. Chain details: {chain_details}")
|
||||
pdb_path = (
|
||||
self.output_dir / f"target_{item_id}_s{current_internal_step}_af2.pdb"
|
||||
)
|
||||
with open(pdb_path, "w") as f:
|
||||
f.write(pdb_content)
|
||||
logger.info(
|
||||
f"Workflow {item_id}: AlphaFold2 PDB saved to {pdb_path}. "
|
||||
f"Chain details: {chain_details}"
|
||||
)
|
||||
|
||||
tool_output = {"success": True, "message": "AlphaFold2 prediction complete.", "target_pdb_preview": pdb_preview, "saved_pdb_path": str(pdb_path)}
|
||||
tool_output = {
|
||||
"success": True,
|
||||
"message": "AlphaFold2 prediction complete.",
|
||||
"target_pdb_preview": pdb_preview,
|
||||
"saved_pdb_path": str(pdb_path),
|
||||
}
|
||||
else:
|
||||
error_detail = api_result.get("error", "AlphaFold2 prediction failed.") if isinstance(api_result, dict) else "AlphaFold2 prediction failed."
|
||||
error_detail = (
|
||||
api_result.get("error", "AlphaFold2 prediction failed.")
|
||||
if isinstance(api_result, dict)
|
||||
else "AlphaFold2 prediction failed."
|
||||
)
|
||||
logger.error(f"Workflow {item_id}: AlphaFold2 call failed: {error_detail}")
|
||||
tool_output = {"success": False, "error": error_detail}
|
||||
state_updates["target_structure_predicted"] = False
|
||||
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
|
||||
async def _run_nim_rfdiffusion(self, args: Dict, workflow_state: Dict) -> Dict:
|
||||
"""
|
||||
Runs RFDiffusion for binder backbone design. Returns structured output with
|
||||
|
|
@ -182,30 +241,55 @@ class ToolExecutor:
|
|||
|
||||
contigs_str_from_llm = args.get("contigs")
|
||||
if not target_pdb_content:
|
||||
tool_output = {"success": False, "error": "Target PDB not found in state for RFDiffusion."}
|
||||
tool_output = {
|
||||
"success": False,
|
||||
"error": "Target PDB not found in state for RFDiffusion.",
|
||||
}
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
if not contigs_str_from_llm:
|
||||
tool_output = {"success": False, "error": "Missing 'contigs' for RFDiffusion."}
|
||||
tool_output = {
|
||||
"success": False,
|
||||
"error": "Missing 'contigs' for RFDiffusion.",
|
||||
}
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
validation_error = self._validate_rfd_contigs(contigs_str_from_llm, target_chain_details)
|
||||
validation_error = self._validate_rfd_contigs(
|
||||
contigs_str_from_llm, target_chain_details
|
||||
)
|
||||
if validation_error:
|
||||
logger.warning(f"RFDiffusion contigs validation failed for item {item_id}: {validation_error}. Contigs: '{contigs_str_from_llm}'")
|
||||
tool_output = {"success": False, "error": f"Invalid contigs: {validation_error}"}
|
||||
logger.warning(
|
||||
f"RFDiffusion contigs validation failed for item {item_id}: {validation_error}. "
|
||||
f"Contigs: '{contigs_str_from_llm}'"
|
||||
)
|
||||
tool_output = {
|
||||
"success": False,
|
||||
"error": f"Invalid contigs: {validation_error}",
|
||||
}
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
hotspot_residues = args.get("hotspot_residues")
|
||||
if hotspot_residues:
|
||||
hotspot_validation_error = self._validate_rfd_hotspots(hotspot_residues, target_chain_details)
|
||||
hotspot_validation_error = self._validate_rfd_hotspots(
|
||||
hotspot_residues, target_chain_details
|
||||
)
|
||||
if hotspot_validation_error:
|
||||
logger.warning(f"RFDiffusion hotspot validation failed for item {item_id}: {hotspot_validation_error}. Hotspots: {hotspot_residues}")
|
||||
tool_output = {"success": False, "error": f"Invalid hotspots: {hotspot_validation_error}"}
|
||||
logger.warning(
|
||||
f"RFDiffusion hotspot validation failed for item {item_id}: {hotspot_validation_error}. "
|
||||
f"Hotspots: {hotspot_residues}"
|
||||
)
|
||||
tool_output = {
|
||||
"success": False,
|
||||
"error": f"Invalid hotspots: {hotspot_validation_error}",
|
||||
}
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
api_result = await call_rfdiffusion(
|
||||
input_pdb=target_pdb_content, api_key=self.nim_api_key,
|
||||
contigs=contigs_str_from_llm, hotspot_res=hotspot_residues,
|
||||
timeout=self.api_timeout, polling_interval=self.polling_interval
|
||||
input_pdb=target_pdb_content,
|
||||
api_key=self.nim_api_key,
|
||||
contigs=contigs_str_from_llm,
|
||||
hotspot_res=hotspot_residues,
|
||||
timeout=self.api_timeout,
|
||||
polling_interval=self.polling_interval,
|
||||
)
|
||||
|
||||
if api_result and "output_pdb" in api_result:
|
||||
|
|
@ -217,20 +301,34 @@ class ToolExecutor:
|
|||
state_updates["binder_pdb_preview"] = binder_pdb_preview
|
||||
state_updates["binder_backbone_designed"] = True
|
||||
|
||||
pdb_path = self.output_dir / f"binder_backbone_{item_id}_s{current_internal_step}_rfd.pdb"
|
||||
with open(pdb_path, "w") as f: f.write(binder_pdb)
|
||||
pdb_path = (
|
||||
self.output_dir
|
||||
/ f"binder_backbone_{item_id}_s{current_internal_step}_rfd.pdb"
|
||||
)
|
||||
with open(pdb_path, "w") as f:
|
||||
f.write(binder_pdb)
|
||||
logger.info(f"Workflow {item_id}: RFDiffusion PDB saved to {pdb_path}")
|
||||
|
||||
tool_output = {"success": True, "message": "RFDiffusion complete.", "binder_backbone_pdb_preview": binder_pdb_preview, "saved_pdb_path": str(pdb_path)}
|
||||
tool_output = {
|
||||
"success": True,
|
||||
"message": "RFDiffusion complete.",
|
||||
"binder_backbone_pdb_preview": binder_pdb_preview,
|
||||
"saved_pdb_path": str(pdb_path),
|
||||
}
|
||||
else:
|
||||
error_detail = api_result.get("error", "RFDiffusion failed.") if isinstance(api_result, dict) else "RFDiffusion failed."
|
||||
logger.error(f"Workflow {item_id}: RFDiffusion call failed: {error_detail}. API Result: {api_result}")
|
||||
error_detail = (
|
||||
api_result.get("error", "RFDiffusion failed.")
|
||||
if isinstance(api_result, dict)
|
||||
else "RFDiffusion failed."
|
||||
)
|
||||
logger.error(
|
||||
f"Workflow {item_id}: RFDiffusion call failed: {error_detail}. API Result: {api_result}"
|
||||
)
|
||||
tool_output = {"success": False, "error": error_detail}
|
||||
state_updates["binder_backbone_designed"] = False
|
||||
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
|
||||
async def _run_nim_proteinmpnn(self, args: Dict, workflow_state: Dict) -> Dict:
|
||||
"""
|
||||
Runs ProteinMPNN for binder sequence design. Returns structured output with
|
||||
|
|
@ -244,20 +342,33 @@ class ToolExecutor:
|
|||
state_updates = {}
|
||||
|
||||
if not binder_pdb:
|
||||
tool_output = {"success": False, "error": "Binder backbone PDB not found for ProteinMPNN."}
|
||||
tool_output = {
|
||||
"success": False,
|
||||
"error": "Binder backbone PDB not found for ProteinMPNN.",
|
||||
}
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
sampling_temp_list = args.get("sampling_temp", [0.1])
|
||||
|
||||
api_result = await call_proteinmpnn(
|
||||
input_pdb=binder_pdb, api_key=self.nim_api_key,
|
||||
input_pdb=binder_pdb,
|
||||
api_key=self.nim_api_key,
|
||||
sampling_temp=sampling_temp_list,
|
||||
timeout=self.api_timeout, polling_interval=self.polling_interval
|
||||
timeout=self.api_timeout,
|
||||
polling_interval=self.polling_interval,
|
||||
)
|
||||
|
||||
if not (api_result and "mfasta" in api_result):
|
||||
error_detail = api_result.get("error", "ProteinMPNN call failed or no mfasta in result.") if isinstance(api_result, dict) else "PMPNN call failed"
|
||||
logger.error(f"Workflow {item_id}: ProteinMPNN call failed: {error_detail}. API Result: {api_result}")
|
||||
error_detail = (
|
||||
api_result.get(
|
||||
"error", "ProteinMPNN call failed or no mfasta in result."
|
||||
)
|
||||
if isinstance(api_result, dict)
|
||||
else "PMPNN call failed"
|
||||
)
|
||||
logger.error(
|
||||
f"Workflow {item_id}: ProteinMPNN call failed: {error_detail}. API Result: {api_result}"
|
||||
)
|
||||
tool_output = {"success": False, "error": error_detail}
|
||||
state_updates["binder_sequence_designed"] = False
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
|
@ -268,12 +379,15 @@ class ToolExecutor:
|
|||
current_sequence_parts: List[str] = []
|
||||
for line_content in fasta_content.splitlines():
|
||||
line = line_content.strip()
|
||||
if not line: continue
|
||||
if not line:
|
||||
continue
|
||||
if line.startswith(">"):
|
||||
if current_header and current_sequence_parts:
|
||||
full_sequence_line = "".join(current_sequence_parts)
|
||||
score_match = re.search(r"global_score=([-\d.]+)", current_header)
|
||||
global_score = float(score_match.group(1)) if score_match else -float('inf')
|
||||
global_score = (
|
||||
float(score_match.group(1)) if score_match else -float("inf")
|
||||
)
|
||||
entries.append((global_score, current_header, full_sequence_line))
|
||||
current_header = line
|
||||
current_sequence_parts = []
|
||||
|
|
@ -282,7 +396,7 @@ class ToolExecutor:
|
|||
if current_header and current_sequence_parts:
|
||||
full_sequence_line = "".join(current_sequence_parts)
|
||||
score_match = re.search(r"global_score=([-\d.]+)", current_header)
|
||||
global_score = float(score_match.group(1)) if score_match else -float('inf')
|
||||
global_score = float(score_match.group(1)) if score_match else -float("inf")
|
||||
entries.append((global_score, current_header, full_sequence_line))
|
||||
|
||||
if not entries:
|
||||
|
|
@ -292,36 +406,59 @@ class ToolExecutor:
|
|||
|
||||
entries.sort(key=lambda x: x[0], reverse=True)
|
||||
best_global_score, best_header, best_full_sequence_line = entries[0]
|
||||
logger.info(f"Workflow {item_id}: Best PMPNN sequence chosen (global_score={best_global_score:.4f}) from header: '{best_header}' -> Seq line: '{best_full_sequence_line}'")
|
||||
logger.info(
|
||||
f"Workflow {item_id}: Best PMPNN sequence chosen (global_score={best_global_score:.4f}) "
|
||||
f"from header: '{best_header}' -> Seq line: '{best_full_sequence_line}'"
|
||||
)
|
||||
|
||||
parsed_binder_chains = [s.strip() for s in best_full_sequence_line.split('/') if s.strip()]
|
||||
parsed_binder_chains = [
|
||||
s.strip() for s in best_full_sequence_line.split("/") if s.strip()
|
||||
]
|
||||
|
||||
if not parsed_binder_chains or not all(s and s.isalpha() and s.isupper() for s in parsed_binder_chains):
|
||||
tool_output = {"success": False, "error": f"Invalid binder chains from PMPNN after parsing '{best_full_sequence_line}'. Parsed: {parsed_binder_chains}"}
|
||||
if not parsed_binder_chains or not all(
|
||||
s and s.isalpha() and s.isupper() for s in parsed_binder_chains
|
||||
):
|
||||
tool_output = {
|
||||
"success": False,
|
||||
"error": (
|
||||
f"Invalid binder chains from PMPNN after parsing '{best_full_sequence_line}'. "
|
||||
f"Parsed: {parsed_binder_chains}"
|
||||
),
|
||||
}
|
||||
state_updates["binder_sequence_designed"] = False
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
state_updates["designed_binder_sequence"] = parsed_binder_chains
|
||||
state_updates["binder_sequence_designed"] = True
|
||||
|
||||
fasta_path = self.output_dir / f"binder_sequence_{item_id}_s{current_internal_step}_pmpnn.fasta"
|
||||
with open(fasta_path, "w") as f: f.write(fasta_content)
|
||||
logger.info(f"Workflow {item_id}: ProteinMPNN FASTA saved to {fasta_path}. Selected binder chains: {parsed_binder_chains}")
|
||||
fasta_path = (
|
||||
self.output_dir
|
||||
/ f"binder_sequence_{item_id}_s{current_internal_step}_pmpnn.fasta"
|
||||
)
|
||||
with open(fasta_path, "w") as f:
|
||||
f.write(fasta_content)
|
||||
logger.info(
|
||||
f"Workflow {item_id}: ProteinMPNN FASTA saved to {fasta_path}. "
|
||||
f"Selected binder chains: {parsed_binder_chains}"
|
||||
)
|
||||
|
||||
preview = parsed_binder_chains[0][:60] + "..." if parsed_binder_chains else "N/A"
|
||||
preview = (
|
||||
parsed_binder_chains[0][:60] + "..." if parsed_binder_chains else "N/A"
|
||||
)
|
||||
if len(parsed_binder_chains) > 1:
|
||||
preview += f" (+ {len(parsed_binder_chains)-1} more chain(s))"
|
||||
|
||||
tool_output = {
|
||||
"success": True,
|
||||
"message": f"ProteinMPNN complete. Selected best (global_score={best_global_score:.4f}).",
|
||||
"message": (
|
||||
f"ProteinMPNN complete. Selected best (global_score={best_global_score:.4f})."
|
||||
),
|
||||
"designed_binder_sequence_list": parsed_binder_chains,
|
||||
"designed_binder_sequence_preview": preview,
|
||||
"saved_fasta_path": str(fasta_path)
|
||||
"saved_fasta_path": str(fasta_path),
|
||||
}
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
|
||||
async def _run_nim_af2_multimer(self, args: Dict, workflow_state: Dict) -> Dict:
|
||||
item_id = workflow_state["item_id"]
|
||||
current_internal_step = workflow_state["current_internal_step"]
|
||||
|
|
@ -331,16 +468,31 @@ class ToolExecutor:
|
|||
tool_output = {}
|
||||
state_updates = {}
|
||||
|
||||
if not target_seq or not designed_binder_chains_list or not isinstance(designed_binder_chains_list, list):
|
||||
tool_output = {"success": False, "error": "Missing or invalid sequences for AF2-Multimer."}
|
||||
if (
|
||||
not target_seq
|
||||
or not designed_binder_chains_list
|
||||
or not isinstance(designed_binder_chains_list, list)
|
||||
):
|
||||
tool_output = {
|
||||
"success": False,
|
||||
"error": "Missing or invalid sequences for AF2-Multimer.",
|
||||
}
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
all_input_sequences_for_multimer = [target_seq] + designed_binder_chains_list
|
||||
|
||||
for i, seq_to_validate in enumerate(all_input_sequences_for_multimer):
|
||||
if not (seq_to_validate and isinstance(seq_to_validate, str) and seq_to_validate.isalpha() and seq_to_validate.isupper()):
|
||||
error_msg = (f"Sequence {i+1} (part of target/binder complex) is invalid: "
|
||||
f"'{str(seq_to_validate)[:30]}...'. Contains non-alpha/lowercase, is empty, or not a string.")
|
||||
if not (
|
||||
seq_to_validate
|
||||
and isinstance(seq_to_validate, str)
|
||||
and seq_to_validate.isalpha()
|
||||
and seq_to_validate.isupper()
|
||||
):
|
||||
error_msg = (
|
||||
f"Sequence {i+1} (part of target/binder complex) is invalid: "
|
||||
f"'{str(seq_to_validate)[:30]}...'. "
|
||||
f"Contains non-alpha/lowercase, is empty, or not a string."
|
||||
)
|
||||
logger.error(f"Workflow {item_id}: {error_msg}")
|
||||
tool_output = {"success": False, "error": error_msg}
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
|
@ -350,28 +502,41 @@ class ToolExecutor:
|
|||
if self.debug_protein_design_calls:
|
||||
self._debug_af2m_call_count += 1
|
||||
mock_plddt = 87.5 if self._debug_af2m_call_count % 2 == 1 else 45.2
|
||||
success_message = f"DEBUG MODE: Returning {'high' if mock_plddt > 50 else 'low'}-quality mock results (call #{self._debug_af2m_call_count})"
|
||||
success_message = (
|
||||
f"DEBUG MODE: Returning {'high' if mock_plddt > 50 else 'low'}-quality "
|
||||
f"mock results (call #{self._debug_af2m_call_count})"
|
||||
)
|
||||
|
||||
debug_pdb_filename = f"complex_{item_id}_s{current_internal_step}_af2m_DEBUG_pLDDT{mock_plddt:.2f}.pdb"
|
||||
debug_pdb_path = self.output_dir / debug_pdb_filename
|
||||
try:
|
||||
with open(debug_pdb_path, "w") as f:
|
||||
f.write(f"REMARK DEBUG PDB FILE for complex. Predicted pLDDT {mock_plddt}\n")
|
||||
logger.info(f"DEBUG MODE: Saved mock AF2-Multimer PDB to {debug_pdb_path}")
|
||||
f.write(
|
||||
f"REMARK DEBUG PDB FILE for complex. Predicted pLDDT {mock_plddt}\n"
|
||||
)
|
||||
logger.info(
|
||||
f"DEBUG MODE: Saved mock AF2-Multimer PDB to {debug_pdb_path}"
|
||||
)
|
||||
state_updates["complex_pdb_content_path"] = str(debug_pdb_path)
|
||||
except IOError as e:
|
||||
logger.error(f"DEBUG MODE: Failed to write mock PDB {debug_pdb_path}: {e}")
|
||||
logger.error(
|
||||
f"DEBUG MODE: Failed to write mock PDB {debug_pdb_path}: {e}"
|
||||
)
|
||||
# If saving fails, don't set the path, but can still proceed with mock pLDDT
|
||||
state_updates["complex_pdb_content_path"] = None
|
||||
|
||||
|
||||
state_updates["af2_multimer_plddt"] = mock_plddt
|
||||
state_updates["complex_evaluated"] = True
|
||||
|
||||
tool_output = {
|
||||
"success": True, "message": f"{success_message}. Mock pLDDT: {mock_plddt:.2f}",
|
||||
"success": True,
|
||||
"message": f"{success_message}. Mock pLDDT: {mock_plddt:.2f}",
|
||||
"plddt": mock_plddt,
|
||||
"complex_file_path": str(debug_pdb_path) if state_updates["complex_pdb_content_path"] else None
|
||||
"complex_file_path": (
|
||||
str(debug_pdb_path)
|
||||
if state_updates["complex_pdb_content_path"]
|
||||
else None
|
||||
),
|
||||
}
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
|
|
@ -380,36 +545,57 @@ class ToolExecutor:
|
|||
api_key=self.nim_api_key,
|
||||
relax_prediction=relax,
|
||||
timeout=self.api_timeout,
|
||||
polling_interval=self.polling_interval
|
||||
polling_interval=self.polling_interval,
|
||||
)
|
||||
|
||||
if api_result is None or (isinstance(api_result, dict) and api_result.get("success") is False):
|
||||
if api_result is None or (
|
||||
isinstance(api_result, dict) and api_result.get("success") is False
|
||||
):
|
||||
error_detail = "AF2-Multimer call failed or returned None."
|
||||
if isinstance(api_result, dict):
|
||||
error_detail = api_result.get("error", "AF2-Multimer call failed with unspecified error.")
|
||||
error_detail = api_result.get(
|
||||
"error", "AF2-Multimer call failed with unspecified error."
|
||||
)
|
||||
detail_info = api_result.get("detail", "")
|
||||
if detail_info: error_detail += f" Details: {detail_info}"
|
||||
if detail_info:
|
||||
error_detail += f" Details: {detail_info}"
|
||||
|
||||
logger.error(f"Workflow {item_id}: AF2-Multimer call failed: {error_detail}. API Result: {api_result}")
|
||||
logger.error(
|
||||
f"Workflow {item_id}: AF2-Multimer call failed: {error_detail}. "
|
||||
f"API Result: {api_result}"
|
||||
)
|
||||
tool_output = {"success": False, "error": error_detail}
|
||||
state_updates["complex_evaluated"] = False
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
all_structures_info = api_result.get("structures")
|
||||
if not all_structures_info or not isinstance(all_structures_info, list):
|
||||
message = api_result.get("message", "No structures returned from AF2-Multimer process.")
|
||||
message = api_result.get(
|
||||
"message", "No structures returned from AF2-Multimer process."
|
||||
)
|
||||
logger.warning(f"Workflow {item_id}: {message}. API Result: {api_result}")
|
||||
if not all_structures_info and isinstance(all_structures_info, list):
|
||||
tool_output = {"success": True, "message": "AF2-Multimer ran, but no PDB structures were produced by the API.", "plddt": 0.0, "complex_file_path": None}
|
||||
state_updates["af2_multimer_plddt"] = 0.0
|
||||
state_updates["complex_evaluated"] = True
|
||||
state_updates["complex_pdb_content_path"] = None
|
||||
tool_output = {
|
||||
"success": True,
|
||||
"message": (
|
||||
"AF2-Multimer ran, but no PDB structures were produced by the API."
|
||||
),
|
||||
"plddt": 0.0,
|
||||
"complex_file_path": None,
|
||||
}
|
||||
state_updates["af2_multimer_plddt"] = 0.0
|
||||
state_updates["complex_evaluated"] = True
|
||||
state_updates["complex_pdb_content_path"] = None
|
||||
else:
|
||||
tool_output = {"success": False, "error": "AF2-Multimer returned unexpected data or no structures."}
|
||||
tool_output = {
|
||||
"success": False,
|
||||
"error": (
|
||||
"AF2-Multimer returned unexpected data or no structures."
|
||||
),
|
||||
}
|
||||
state_updates["complex_evaluated"] = False
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
|
||||
best_structure_info = None
|
||||
highest_plddt = -1.0
|
||||
|
||||
|
|
@ -420,8 +606,14 @@ class ToolExecutor:
|
|||
best_structure_info = struct_info
|
||||
|
||||
if best_structure_info is None:
|
||||
logger.error(f"Workflow {item_id}: No valid structure with pLDDT found in AF2-Multimer results, though structures were present.")
|
||||
tool_output = {"success": False, "error": "No valid structure with pLDDT in AF2-Multimer results."}
|
||||
logger.error(
|
||||
f"Workflow {item_id}: No valid structure with pLDDT found in AF2-Multimer results, "
|
||||
f"though structures were present."
|
||||
)
|
||||
tool_output = {
|
||||
"success": False,
|
||||
"error": ("No valid structure with pLDDT in AF2-Multimer results."),
|
||||
}
|
||||
state_updates["complex_evaluated"] = False
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
|
|
@ -430,53 +622,81 @@ class ToolExecutor:
|
|||
best_model_idx = best_structure_info.get("model_index", "NA")
|
||||
|
||||
if not best_pdb_content:
|
||||
logger.error(f"Workflow {item_id}: Best AF2-Multimer structure (Model {best_model_idx}, pLDDT {best_plddt:.2f}) found, but PDB content is missing.")
|
||||
tool_output = {"success": False, "error": f"Best model (pLDDT {best_plddt:.2f}) has no PDB content."}
|
||||
logger.error(
|
||||
f"Workflow {item_id}: Best AF2-Multimer structure (Model {best_model_idx}, "
|
||||
f"pLDDT {best_plddt:.2f}) found, but PDB content is missing."
|
||||
)
|
||||
tool_output = {
|
||||
"success": False,
|
||||
"error": f"Best model (pLDDT {best_plddt:.2f}) has no PDB content.",
|
||||
}
|
||||
state_updates["complex_evaluated"] = False
|
||||
state_updates["af2_multimer_plddt"] = best_plddt
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
complex_pdb_filename = f"complex_{item_id}_s{current_internal_step}_af2m_model{best_model_idx}_pLDDT{best_plddt:.2f}.pdb"
|
||||
complex_pdb_filename = (
|
||||
f"complex_{item_id}_s{current_internal_step}_af2m_model{best_model_idx}_"
|
||||
f"pLDDT{best_plddt:.2f}.pdb"
|
||||
)
|
||||
complex_pdb_path = self.output_dir / complex_pdb_filename
|
||||
|
||||
try:
|
||||
with open(complex_pdb_path, "w", encoding='utf-8') as f:
|
||||
with open(complex_pdb_path, "w", encoding="utf-8") as f:
|
||||
f.write(best_pdb_content)
|
||||
logger.info(f"Workflow {item_id}: AlphaFold2-Multimer complete. Saved best model (Index {best_model_idx}) with pLDDT: {best_plddt:.2f} from {len(all_structures_info)} models to {complex_pdb_path}")
|
||||
logger.info(
|
||||
f"Workflow {item_id}: AlphaFold2-Multimer complete. Saved best model (Index {best_model_idx}) "
|
||||
f"with pLDDT: {best_plddt:.2f} from {len(all_structures_info)} models to {complex_pdb_path}"
|
||||
)
|
||||
|
||||
state_updates["complex_pdb_content_path"] = str(complex_pdb_path)
|
||||
state_updates["af2_multimer_plddt"] = best_plddt
|
||||
state_updates["complex_evaluated"] = True
|
||||
|
||||
complex_quality_message = f"AlphaFold2-Multimer evaluation complete. Selected best model (Index {best_model_idx}) with pLDDT: {best_plddt:.2f}"
|
||||
complex_quality_message = (
|
||||
f"AlphaFold2-Multimer evaluation complete. Selected best model (Index {best_model_idx}) "
|
||||
f"with pLDDT: {best_plddt:.2f}"
|
||||
)
|
||||
|
||||
tool_output = {
|
||||
"success": True,
|
||||
"message": complex_quality_message,
|
||||
"plddt": best_plddt,
|
||||
"complex_file_path": str(complex_pdb_path),
|
||||
"selected_model_index": best_model_idx
|
||||
"selected_model_index": best_model_idx,
|
||||
}
|
||||
except IOError as e:
|
||||
logger.error(f"Workflow {item_id}: Failed to save best AF2-Multimer PDB (Model {best_model_idx}, pLDDT {best_plddt:.2f}) to {complex_pdb_path}: {e}")
|
||||
tool_output = {"success": False, "error": f"Failed to save best complex PDB: {e}"}
|
||||
logger.error(
|
||||
f"Workflow {item_id}: Failed to save best AF2-Multimer PDB (Model {best_model_idx}, "
|
||||
f"pLDDT {best_plddt:.2f}) to {complex_pdb_path}: {e}"
|
||||
)
|
||||
tool_output = {
|
||||
"success": False,
|
||||
"error": f"Failed to save best complex PDB: {e}",
|
||||
}
|
||||
state_updates["af2_multimer_plddt"] = best_plddt
|
||||
state_updates["complex_pdb_content_path"] = None
|
||||
state_updates["complex_evaluated"] = True
|
||||
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
|
||||
async def dispatch_tool_call(self, tool_name: str, args: Dict, workflow_state: Dict) -> Dict:
|
||||
async def dispatch_tool_call(
|
||||
self, tool_name: str, args: Dict, workflow_state: Dict
|
||||
) -> Dict:
|
||||
"""Main dispatch method for executing tools."""
|
||||
item_id = workflow_state["item_id"]
|
||||
internal_step = workflow_state["current_internal_step"]
|
||||
logger.info(f"ToolExecutor: Dispatching tool '{tool_name}' for workflow {item_id}, Step {internal_step} with args: {args}")
|
||||
logger.info(
|
||||
f"ToolExecutor: Dispatching tool '{tool_name}' for workflow {item_id}, "
|
||||
f"Step {internal_step} with args: {args}"
|
||||
)
|
||||
|
||||
if not self.nim_api_key:
|
||||
return {
|
||||
"tool_output": {"success": False, "error": "NIM API key not configured in ToolExecutor."},
|
||||
"state_updates": {}
|
||||
"tool_output": {
|
||||
"success": False,
|
||||
"error": "NIM API key not configured in ToolExecutor.",
|
||||
},
|
||||
"state_updates": {},
|
||||
}
|
||||
|
||||
if tool_name == "predict_target_structure_alphafold2":
|
||||
|
|
@ -488,8 +708,13 @@ class ToolExecutor:
|
|||
elif tool_name == "evaluate_binder_complex_alphafold2_multimer":
|
||||
return await self._run_nim_af2_multimer(args, workflow_state)
|
||||
else:
|
||||
logger.error(f"ToolExecutor: Unknown tool name '{tool_name}' for workflow {item_id}")
|
||||
logger.error(
|
||||
f"ToolExecutor: Unknown tool name '{tool_name}' for workflow {item_id}"
|
||||
)
|
||||
return {
|
||||
"tool_output": {"success": False, "error": f"Unknown tool name: {tool_name}"},
|
||||
"state_updates": {}
|
||||
"tool_output": {
|
||||
"success": False,
|
||||
"error": f"Unknown tool name: {tool_name}",
|
||||
},
|
||||
"state_updates": {},
|
||||
}
|
||||
|
|
@ -2,4 +2,4 @@
|
|||
|
||||
from .pdb_utils import get_pdb_chain_details
|
||||
|
||||
__all__ = ["get_pdb_chain_details"]
|
||||
__all__ = ["get_pdb_chain_details"]
|
||||
|
|
@ -1,13 +1,13 @@
|
|||
import os
|
||||
import logging
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_api_key() -> Optional[str]:
|
||||
"""
|
||||
Load the NVIDIA NIM API key from environment variables.
|
||||
|
|
@ -17,8 +17,10 @@ def load_api_key() -> Optional[str]:
|
|||
"""
|
||||
api_key = os.environ.get("NVIDIA_NIM_API_KEY")
|
||||
if not api_key:
|
||||
logger.error("NVIDIA_NIM_API_KEY not found in environment variables. "
|
||||
"Please set it in your .env file.")
|
||||
logger.error(
|
||||
"NVIDIA_NIM_API_KEY not found in environment variables. "
|
||||
"Please set it in your .env file."
|
||||
)
|
||||
return None
|
||||
|
||||
return api_key
|
||||
|
|
@ -1,9 +1,12 @@
|
|||
import logging
|
||||
from typing import Dict, Tuple, List, Set, Union
|
||||
from typing import Dict, Set, Tuple, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_pdb_chain_details(pdb_content: str, preview_lines: int = 10) -> Tuple[Dict[str, Dict[str, int]], str]:
|
||||
|
||||
def get_pdb_chain_details(
|
||||
pdb_content: str, preview_lines: int = 10
|
||||
) -> Tuple[Dict[str, Dict[str, int]], str]:
|
||||
"""
|
||||
Parses PDB content to extract detailed information for each chain.
|
||||
|
||||
|
|
@ -27,7 +30,8 @@ def get_pdb_chain_details(pdb_content: str, preview_lines: int = 10) -> Tuple[Di
|
|||
if line.startswith("ATOM"):
|
||||
atom_lines.append(line)
|
||||
chain_id = line[21:22].strip()
|
||||
if not chain_id: chain_id = " "
|
||||
if not chain_id:
|
||||
chain_id = " "
|
||||
atom_name = line[12:16].strip()
|
||||
try:
|
||||
residue_num = int(line[22:26].strip())
|
||||
|
|
@ -39,7 +43,11 @@ def get_pdb_chain_details(pdb_content: str, preview_lines: int = 10) -> Tuple[Di
|
|||
except ValueError:
|
||||
logger.warning(f"Could not parse residue number from PDB line: {line}")
|
||||
continue
|
||||
elif line.startswith("HEADER") or line.startswith("TITLE") or line.startswith("COMPND"):
|
||||
elif (
|
||||
line.startswith("HEADER")
|
||||
or line.startswith("TITLE")
|
||||
or line.startswith("COMPND")
|
||||
):
|
||||
header_lines.append(line)
|
||||
|
||||
chain_details: Dict[str, Dict[str, int]] = {}
|
||||
|
|
@ -51,14 +59,16 @@ def get_pdb_chain_details(pdb_content: str, preview_lines: int = 10) -> Tuple[Di
|
|||
chain_details[chain_id] = {
|
||||
"min_residue": min_res,
|
||||
"max_residue": max_res,
|
||||
"length": length
|
||||
"length": length,
|
||||
}
|
||||
else:
|
||||
logger.warning(f"Chain {chain_id} had no parseable ATOM residue numbers.")
|
||||
|
||||
preview_str_parts = header_lines[:min(len(header_lines), preview_lines // 2)]
|
||||
preview_str_parts = header_lines[: min(len(header_lines), preview_lines // 2)]
|
||||
remaining_preview_lines = preview_lines - len(preview_str_parts)
|
||||
preview_str_parts.extend(atom_lines[:min(len(atom_lines), remaining_preview_lines)])
|
||||
preview_str_parts.extend(
|
||||
atom_lines[: min(len(atom_lines), remaining_preview_lines)]
|
||||
)
|
||||
pdb_preview = "\n".join(preview_str_parts)
|
||||
if len(pdb_content.splitlines()) > preview_lines:
|
||||
pdb_preview += "\n..."
|
||||
|
|
@ -1 +0,0 @@
|
|||
"""Protein design model API modules."""
|
||||
|
|
@ -1,67 +0,0 @@
|
|||
PREDICT_TARGET_STRUCTURE_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "predict_target_structure_alphafold2",
|
||||
"description": "Predicts the 3D structure of the target protein sequence using AlphaFold2.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sequence": {"type": "string", "description": "Amino acid sequence of the target protein."},
|
||||
},
|
||||
"required": ["sequence"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
DESIGN_BINDER_BACKBONE_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "design_binder_backbone_rfdiffusion",
|
||||
"description": "Generates a novel protein binder backbone using RFDiffusion, conditioned on the target protein structure.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"contigs": {"type": "string", "description": "RFDiffusion contigs (e.g., 'A1-100/0 50-70')."},
|
||||
"hotspot_residues": {"type": "array", "items": {"type": "string"}, "description": "Optional hotspot residues (e.g., ['A50', 'A52'])."},
|
||||
},
|
||||
"required": ["contigs"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
DESIGN_BINDER_SEQUENCE_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "design_binder_sequence_proteinmpnn",
|
||||
"description": "Designs an amino acid sequence for the generated binder backbone.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"sampling_temp": {"type": "array", "items": {"type": "number"}, "description": "Sampling temperatures (e.g., [0.1, 0.2]). Default [0.1]."}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
EVALUATE_COMPLEX_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "evaluate_binder_complex_alphafold2_multimer",
|
||||
"description": "Predicts the complex structure of target and designed binder, providing quality scores.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"relax_prediction": {"type": "boolean", "description": "Whether to relax the prediction. Default True."}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ALL_TOOLS_LIST = [
|
||||
PREDICT_TARGET_STRUCTURE_TOOL,
|
||||
DESIGN_BINDER_BACKBONE_TOOL,
|
||||
DESIGN_BINDER_SEQUENCE_TOOL,
|
||||
EVALUATE_COMPLEX_TOOL
|
||||
]
|
||||
Loading…
Add table
Add a link
Reference in a new issue