From 54967ecae96d46a3d2ec85c59b0aee77f0b1bc45 Mon Sep 17 00:00:00 2001 From: Shannon Sands Date: Tue, 27 May 2025 12:15:15 +1000 Subject: [PATCH] linting --- environments/community/README.md | 59 ++ .../protein_design}/.env.example | 2 +- .../protein_design}/README.md | 6 +- .../configs/binderbench_default.yaml | 0 .../protein_design}/debug_target.pdb | 0 .../protein_design/models/__init__.py | 1 + .../protein_design}/models/alphafold2.py | 38 +- .../models/alphafold2_multimer.py | 261 +++++-- .../protein_design}/models/proteinmpnn.py | 37 +- .../protein_design}/models/rfdiffusion.py | 40 +- .../protein_design}/prompts.py | 107 +-- .../protein_design}/protein_env.py | 660 +++++++++++++----- .../protein_design/tool_definitions.py | 92 +++ .../protein_design}/tool_executor.py | 459 ++++++++---- .../protein_design}/utils/__init__.py | 2 +- .../protein_design}/utils/api_utils.py | 12 +- .../protein_design}/utils/pdb_utils.py | 24 +- .../protein_design_env/models/__init__.py | 1 - .../protein_design_env/tool_definitions.py | 67 -- 19 files changed, 1337 insertions(+), 531 deletions(-) rename environments/{hack0/protein_design_env => community/protein_design}/.env.example (61%) rename environments/{hack0/protein_design_env => community/protein_design}/README.md (96%) rename environments/{hack0/protein_design_env => community/protein_design}/configs/binderbench_default.yaml (100%) rename environments/{hack0/protein_design_env => community/protein_design}/debug_target.pdb (100%) create mode 100644 environments/community/protein_design/models/__init__.py rename environments/{hack0/protein_design_env => community/protein_design}/models/alphafold2.py (86%) rename environments/{hack0/protein_design_env => community/protein_design}/models/alphafold2_multimer.py (51%) rename environments/{hack0/protein_design_env => community/protein_design}/models/proteinmpnn.py (85%) rename environments/{hack0/protein_design_env => community/protein_design}/models/rfdiffusion.py (84%) rename environments/{hack0/protein_design_env => community/protein_design}/prompts.py (56%) rename environments/{hack0/protein_design_env => community/protein_design}/protein_env.py (52%) create mode 100644 environments/community/protein_design/tool_definitions.py rename environments/{hack0/protein_design_env => community/protein_design}/tool_executor.py (52%) rename environments/{hack0/protein_design_env => community/protein_design}/utils/__init__.py (74%) rename environments/{hack0/protein_design_env => community/protein_design}/utils/api_utils.py (71%) rename environments/{hack0/protein_design_env => community/protein_design}/utils/pdb_utils.py (80%) delete mode 100644 environments/hack0/protein_design_env/models/__init__.py delete mode 100644 environments/hack0/protein_design_env/tool_definitions.py diff --git a/environments/community/README.md b/environments/community/README.md index 979f70f1..9784ba0a 100644 --- a/environments/community/README.md +++ b/environments/community/README.md @@ -2040,6 +2040,65 @@ python test_stl_env.py --- +### 23. Protein Design Environment (`protein_design/`) + +**Contributors**: hallerite, promachina +**PR**: [#70](https://github.com/NousResearch/atropos/pull/70) +**Integration Status**: ✅ Integrated + +**Description**: A comprehensive reinforcement learning environment for de novo protein design through a staged simulation loop. This environment enables AI systems to learn the complete protein design workflow from target structure prediction to binder evaluation, using state-of-the-art protein modeling tools. + +**Core Features**: + +**Multi-Stage Protein Design Pipeline**: +- **AlphaFold2 Structure Prediction**: Predicts 3D structure of target proteins from amino acid sequences +- **RFDiffusion Backbone Generation**: Generates novel protein binder backbones conditioned on target structures +- **ProteinMPNN Sequence Design**: Designs optimal amino acid sequences for generated backbones +- **AlphaFold2-Multimer Evaluation**: Evaluates binding complex quality with pLDDT scoring + +**Advanced Workflow Management**: +- **State-Based Progression**: Tracks workflow state through 4 distinct internal steps +- **Retry Logic**: Configurable retry mechanisms for failed tool executions +- **Validation Systems**: Comprehensive input validation for contigs, hotspots, and sequences +- **Error Handling**: Robust error recovery and detailed logging + +**NVIDIA NIM Integration**: +- **API-Based Execution**: Leverages NVIDIA NIM APIs for protein modeling tools +- **Async Processing**: Non-blocking API calls with configurable timeouts and polling +- **Debug Mode**: Mock data generation for development and testing +- **Result Caching**: Saves intermediate PDB files and FASTA sequences + +**Reward System**: +- **Format Rewards**: 0.2 points for correct tool usage in steps 0-2 +- **Quality Rewards**: pLDDT-based scoring (0.0-1.0) for final complex evaluation +- **Progressive Scoring**: Cumulative rewards across workflow stages + +**Data Management**: +- **Hugging Face Integration**: Loads protein binding datasets (ronig/protein_binding_sequences) +- **File Organization**: Structured output directory with timestamped results +- **Comprehensive Logging**: Detailed workflow tracking and performance metrics + +**Research Applications**: +- **Drug Discovery**: Design novel protein binders for therapeutic targets +- **Protein Engineering**: Optimize protein-protein interactions +- **Structural Biology**: Explore protein design space systematically +- **AI Training**: Develop protein design capabilities in language models + +**Technical Requirements**: +- NVIDIA NIM API access for protein modeling tools +- Python environment with protein analysis libraries +- Sufficient storage for PDB files and intermediate results + +**Environment Configuration**: +- Configurable retry limits and timeout settings +- Debug mode for development without API calls +- Flexible dataset selection and column mapping +- WandB integration for experiment tracking + +**Requirements**: pydantic, datasets, python-dotenv, pyyaml, wandb, atroposlib, nvidia-nim-api-client + +--- + ## Support For questions or issues with community environments: diff --git a/environments/hack0/protein_design_env/.env.example b/environments/community/protein_design/.env.example similarity index 61% rename from environments/hack0/protein_design_env/.env.example rename to environments/community/protein_design/.env.example index 26fc5a46..bb47cfd0 100644 --- a/environments/hack0/protein_design_env/.env.example +++ b/environments/community/protein_design/.env.example @@ -1,3 +1,3 @@ # We use NVIDIA NIM to access hosted models on the API -NVIDIA_NIM_API_KEY: "YOUR API KEY" \ No newline at end of file +NVIDIA_NIM_API_KEY: "YOUR API KEY" diff --git a/environments/hack0/protein_design_env/README.md b/environments/community/protein_design/README.md similarity index 96% rename from environments/hack0/protein_design_env/README.md rename to environments/community/protein_design/README.md index 6f962c98..518316fa 100644 --- a/environments/hack0/protein_design_env/README.md +++ b/environments/community/protein_design/README.md @@ -38,7 +38,7 @@ Each episode consists of an LLM navigating a 4-step design pipeline, using state ### Step 4: Evaluate Binding (`AlphaFold-Multimer`) - **Input:** Target + binder sequences - **Output:** Complex structure prediction -- **Reward:** +- **Reward:** - Format OK - No steric clashes - **Bonus:** Contact interface, binding affinity metrics (Not yet implemented) @@ -48,9 +48,9 @@ Each episode consists of an LLM navigating a 4-step design pipeline, using state ## 🏆 Reward Function The reward is cumulative: -- **+0.2**: Successfully generate output in correct format at each step +- **+0.2**: Successfully generate output in correct format at each step - **+0.0 to +1.0:** Structural reward based on complex validity smoothly interpolated on AlphaFold2 multimere confidence -- **+1**: High predicted binding affinity (Not yet implemented) +- **+1**: High predicted binding affinity (Not yet implemented) Sparse, but real. LLMs must *plan* tool use, not just spam actions. diff --git a/environments/hack0/protein_design_env/configs/binderbench_default.yaml b/environments/community/protein_design/configs/binderbench_default.yaml similarity index 100% rename from environments/hack0/protein_design_env/configs/binderbench_default.yaml rename to environments/community/protein_design/configs/binderbench_default.yaml diff --git a/environments/hack0/protein_design_env/debug_target.pdb b/environments/community/protein_design/debug_target.pdb similarity index 100% rename from environments/hack0/protein_design_env/debug_target.pdb rename to environments/community/protein_design/debug_target.pdb diff --git a/environments/community/protein_design/models/__init__.py b/environments/community/protein_design/models/__init__.py new file mode 100644 index 00000000..05d40302 --- /dev/null +++ b/environments/community/protein_design/models/__init__.py @@ -0,0 +1 @@ +"""Protein design model API modules.""" diff --git a/environments/hack0/protein_design_env/models/alphafold2.py b/environments/community/protein_design/models/alphafold2.py similarity index 86% rename from environments/hack0/protein_design_env/models/alphafold2.py rename to environments/community/protein_design/models/alphafold2.py index af415ab1..4fae379e 100644 --- a/environments/hack0/protein_design_env/models/alphafold2.py +++ b/environments/community/protein_design/models/alphafold2.py @@ -1,16 +1,15 @@ -import os -import logging -import aiohttp -import json import asyncio -from typing import Dict, List, Any, Optional -from pathlib import Path +import logging +from typing import Any, Dict, List, Optional + +import aiohttp logger = logging.getLogger(__name__) DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2" DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status" + async def call_alphafold2( sequence: str, api_key: str, @@ -24,7 +23,7 @@ async def call_alphafold2( status_url: str = DEFAULT_STATUS_URL, polling_interval: int = 10, timeout: int = 600, - max_retries: int = 3 + max_retries: int = 3, ) -> Optional[Dict[str, Any]]: """ Call the NVIDIA NIM AlphaFold2 API. @@ -59,16 +58,13 @@ async def call_alphafold2( "iterations": iterations, "databases": databases, "relax_prediction": relax_prediction, - "skip_template_search": skip_template_search + "skip_template_search": skip_template_search, } try: async with aiohttp.ClientSession() as session: async with session.post( - url, - json=data, - headers=headers, - timeout=timeout + url, json=data, headers=headers, timeout=timeout ) as response: if response.status == 200: return await response.json() @@ -81,7 +77,7 @@ async def call_alphafold2( headers=headers, status_url=status_url, polling_interval=polling_interval, - timeout=timeout + timeout=timeout, ) else: logger.error("No request ID in response headers") @@ -93,16 +89,18 @@ async def call_alphafold2( return None except Exception as e: import traceback + logger.error(f"Error calling AlphaFold2 API: {e}") logger.error(traceback.format_exc()) return None + async def _poll_job_status( req_id: str, headers: Dict[str, str], status_url: str, polling_interval: int = 10, - timeout: int = 60 + timeout: int = 60, ) -> Optional[Dict[str, Any]]: """ Poll the status endpoint until the job completes. @@ -121,18 +119,20 @@ async def _poll_job_status( try: async with aiohttp.ClientSession() as session: async with session.get( - f"{status_url}/{req_id}", - headers=headers, - timeout=timeout + f"{status_url}/{req_id}", headers=headers, timeout=timeout ) as response: if response.status == 200: logger.info(f"AlphaFold2 job {req_id} completed") return await response.json() elif response.status == 202: - logger.debug(f"AlphaFold2 job {req_id} still running, polling...") + logger.debug( + f"AlphaFold2 job {req_id} still running, polling..." + ) await asyncio.sleep(polling_interval) else: - logger.error(f"Error checking AlphaFold2 job status: {response.status}") + logger.error( + f"Error checking AlphaFold2 job status: {response.status}" + ) text = await response.text() logger.error(f"Response: {text}") return None diff --git a/environments/hack0/protein_design_env/models/alphafold2_multimer.py b/environments/community/protein_design/models/alphafold2_multimer.py similarity index 51% rename from environments/hack0/protein_design_env/models/alphafold2_multimer.py rename to environments/community/protein_design/models/alphafold2_multimer.py index 25a855cb..7383d06f 100644 --- a/environments/hack0/protein_design_env/models/alphafold2_multimer.py +++ b/environments/community/protein_design/models/alphafold2_multimer.py @@ -1,16 +1,16 @@ -import os -import logging -import aiohttp -import json import asyncio -from typing import Dict, List, Any, Optional, Tuple -from pathlib import Path +import json +import logging +from typing import Any, Dict, List, Optional, Tuple + +import aiohttp logger = logging.getLogger(__name__) DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2-multimer" DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status" + def _split_pdb_content(concatenated_pdb_str: str) -> List[str]: """ Splits a string containing concatenated PDB file contents. @@ -35,7 +35,9 @@ def _split_pdb_content(concatenated_pdb_str: str) -> List[str]: return [pdb for pdb in pdbs if pdb] -def calculate_plddt_from_pdb_string(pdb_string: str) -> Tuple[float, List[float], Dict[str, List[float]]]: +def calculate_plddt_from_pdb_string( + pdb_string: str, +) -> Tuple[float, List[float], Dict[str, List[float]]]: total_plddt = 0.0 ca_atom_count = 0 plddt_scores_per_ca: List[float] = [] @@ -67,10 +69,11 @@ def calculate_plddt_from_pdb_string(pdb_string: str) -> Tuple[float, List[float] average_plddt = total_plddt / ca_atom_count return average_plddt, plddt_scores_per_ca, plddt_scores_per_chain + async def _process_pdb_and_scores_from_api( pdb_contents: List[str], job_id: str, - api_response_json: Optional[Dict[str, Any]] = None + api_response_json: Optional[Dict[str, Any]] = None, ) -> Optional[Dict[str, Any]]: """ Processes a list of PDB strings received from the API. @@ -78,14 +81,19 @@ async def _process_pdb_and_scores_from_api( - Does NOT save files to disk. - Returns a dictionary containing a list of structures, each with its PDB content and scores. """ - results: Dict[str, Any] = { - "structures": [] - } + results: Dict[str, Any] = {"structures": []} - if not pdb_contents or not isinstance(pdb_contents, list) or not all(isinstance(s, str) for s in pdb_contents): + if ( + not pdb_contents + or not isinstance(pdb_contents, list) + or not all(isinstance(s, str) for s in pdb_contents) + ): logger.warning(f"No valid PDB content strings provided for job {job_id}.") - return {"success": False, "error": "No valid PDB content strings from API.", "structures": []} - + return { + "success": False, + "error": "No valid PDB content strings from API.", + "structures": [], + } logger.info(f"Processing {len(pdb_contents)} PDB structure(s) for job {job_id}.") @@ -94,12 +102,11 @@ async def _process_pdb_and_scores_from_api( logger.debug(f"Skipping empty PDB string at index {i} for job {job_id}.") continue - structure_data: Dict[str, Any] = { - "model_index": i, - "pdb_content": pdb_str - } + structure_data: Dict[str, Any] = {"model_index": i, "pdb_content": pdb_str} - avg_plddt, plddts_per_ca_residue, plddts_by_chain = calculate_plddt_from_pdb_string(pdb_str) + avg_plddt, plddts_per_ca_residue, plddts_by_chain = ( + calculate_plddt_from_pdb_string(pdb_str) + ) structure_data["average_plddt"] = avg_plddt structure_data["plddt_scores_per_ca_residue"] = plddts_per_ca_residue @@ -116,13 +123,21 @@ async def _process_pdb_and_scores_from_api( results["structures"].append(structure_data) if results["structures"]: - logger.info(f"Successfully processed and calculated pLDDTs for {len(results['structures'])} structures for job {job_id}.") + logger.info( + f"Successfully processed and calculated pLDDTs for " + f"{len(results['structures'])} structures for job {job_id}." + ) else: logger.warning(f"No structures were processed for job {job_id}.") - return {"success": True, "message": "No PDB structures found in API response to process.", "structures": []} + return { + "success": True, + "message": "No PDB structures found in API response to process.", + "structures": [], + } return results + async def call_alphafold2_multimer( sequences: List[str], api_key: str, @@ -135,7 +150,7 @@ async def call_alphafold2_multimer( url: str = DEFAULT_URL, status_url: str = DEFAULT_STATUS_URL, polling_interval: int = 30, - timeout: int = 3600 + timeout: int = 3600, ) -> Optional[Dict[str, Any]]: """ Call the NVIDIA NIM AlphaFold2-Multimer API. @@ -155,7 +170,7 @@ async def call_alphafold2_multimer( "e_value": e_value, "iterations": iterations, "databases": databases, - "relax_prediction": relax_prediction + "relax_prediction": relax_prediction, } if selected_models is not None: data["selected_models"] = selected_models @@ -165,10 +180,7 @@ async def call_alphafold2_multimer( initial_post_timeout = min(timeout, 600) async with aiohttp.ClientSession() as session: async with session.post( - url, - json=data, - headers=headers, - timeout=initial_post_timeout + url, json=data, headers=headers, timeout=initial_post_timeout ) as response: if response.status == 200: logger.info("AlphaFold2-Multimer job completed synchronously.") @@ -177,130 +189,239 @@ async def call_alphafold2_multimer( if "application/json" in content_type: api_response_json_payload = await response.json() if not isinstance(api_response_json_payload, list): - if isinstance(api_response_json_payload, dict) and "error" in api_response_json_payload: - logger.error(f"Sync API call returned error: {api_response_json_payload['error']}") - return {"success": False, "error": api_response_json_payload['error'], "detail": api_response_json_payload.get("detail","")} - return {"success": False, "error": "Sync JSON response not a list of PDBs as expected."} + if ( + isinstance(api_response_json_payload, dict) + and "error" in api_response_json_payload + ): + logger.error( + f"Sync API call returned error: " + f"{api_response_json_payload['error']}" + ) + return { + "success": False, + "error": api_response_json_payload["error"], + "detail": api_response_json_payload.get( + "detail", "" + ), + } + return { + "success": False, + "error": "Sync JSON response not a list of PDBs as expected.", + } req_id_sync = response.headers.get("nvcf-reqid", "sync_job") return await _process_pdb_and_scores_from_api( pdb_contents=api_response_json_payload, job_id=req_id_sync, - api_response_json=None + api_response_json=None, ) else: err_text = await response.text() - logger.error(f"Sync response unexpected content type: {content_type}. Response: {err_text[:500]}") - return {"success": False, "error": f"Sync response unexpected content type: {content_type}", "detail": err_text} + logger.error( + f"Sync response unexpected content type: {content_type}. " + f"Response: {err_text[:500]}" + ) + return { + "success": False, + "error": f"Sync response unexpected content type: {content_type}", + "detail": err_text, + } elif response.status == 202: req_id = response.headers.get("nvcf-reqid") if req_id: - logger.info(f"AlphaFold2-Multimer job submitted, request ID: {req_id}") + logger.info( + f"AlphaFold2-Multimer job submitted, request ID: {req_id}" + ) return await _poll_job_status( req_id=req_id, headers=headers, status_url=status_url, polling_interval=polling_interval, - overall_timeout=timeout + overall_timeout=timeout, ) else: logger.error("No request ID in 202 response headers") - return {"success": False, "error": "No request ID in 202 response headers"} + return { + "success": False, + "error": "No request ID in 202 response headers", + } else: - logger.error(f"Error calling AlphaFold2-Multimer API (POST): {response.status}") + logger.error( + f"Error calling AlphaFold2-Multimer API (POST): {response.status}" + ) text = await response.text() logger.error(f"Response: {text}") - return {"success": False, "error": f"Error calling API: {response.status}", "detail": text} + return { + "success": False, + "error": f"Error calling API: {response.status}", + "detail": text, + } except asyncio.TimeoutError: - logger.error(f"Timeout during AlphaFold2-Multimer API (initial POST).") + logger.error("Timeout during AlphaFold2-Multimer API (initial POST).") return {"success": False, "error": "Timeout during initial API request"} except Exception as e: - logger.error(f"Exception during AlphaFold2-Multimer API call (initial POST): {e}", exc_info=True) + logger.error( + f"Exception during AlphaFold2-Multimer API call (initial POST): {e}", + exc_info=True, + ) return {"success": False, "error": f"Exception during API call: {str(e)}"} + async def _poll_job_status( req_id: str, headers: Dict[str, str], status_url: str, polling_interval: int = 30, - overall_timeout: int = 3600 + overall_timeout: int = 3600, ) -> Optional[Dict[str, Any]]: start_time = asyncio.get_event_loop().time() per_status_request_timeout = 600 - logger.info(f"Polling job {req_id}. Individual status check timeout: {per_status_request_timeout}s, Polling interval: {polling_interval}s, Overall timeout: {overall_timeout}s") + logger.info( + f"Polling job {req_id}. Individual status check timeout: " + f"{per_status_request_timeout}s, Polling interval: {polling_interval}s, " + f"Overall timeout: {overall_timeout}s" + ) while True: current_loop_time = asyncio.get_event_loop().time() elapsed_time = current_loop_time - start_time if elapsed_time >= overall_timeout: - logger.error(f"Overall polling timeout of {overall_timeout}s exceeded for job {req_id}.") + logger.error( + f"Overall polling timeout of {overall_timeout}s exceeded for " + f"job {req_id}." + ) return {"success": False, "error": "Overall polling timeout exceeded."} remaining_time_for_overall_timeout = overall_timeout - elapsed_time - current_status_check_timeout = min(per_status_request_timeout, remaining_time_for_overall_timeout) + current_status_check_timeout = min( + per_status_request_timeout, remaining_time_for_overall_timeout + ) if current_status_check_timeout <= 0: - logger.error(f"Not enough time left for another status check for job {req_id} within overall_timeout.") - return {"success": False, "error": "Not enough time for status check within overall timeout."} + logger.error( + f"Not enough time left for another status check for job {req_id} " + f"within overall_timeout." + ) + return { + "success": False, + "error": "Not enough time for status check within overall timeout.", + } try: async with aiohttp.ClientSession() as session: - logger.debug(f"Checking status for {req_id} with timeout {current_status_check_timeout}s.") + logger.debug( + f"Checking status for {req_id} with timeout " + f"{current_status_check_timeout}s." + ) async with session.get( f"{status_url}/{req_id}", headers=headers, - timeout=current_status_check_timeout + timeout=current_status_check_timeout, ) as response: if response.status == 200: - logger.info(f"AlphaFold2-Multimer job {req_id} completed (status 200).") - if response.content_type == 'application/json': + logger.info( + f"AlphaFold2-Multimer job {req_id} completed (status 200)." + ) + if response.content_type == "application/json": try: api_response_json_payload = await response.json() if not isinstance(api_response_json_payload, list): - if isinstance(api_response_json_payload, dict) and "error" in api_response_json_payload: - logger.error(f"Job {req_id}: API returned error: {api_response_json_payload['error']}") - return {"success": False, "error": api_response_json_payload['error'], "detail": api_response_json_payload.get("detail","")} - logger.error(f"Job {req_id}: Expected API response to be a list of PDB strings, got {type(api_response_json_payload)}.") - return {"success": False, "error": "API response was not a list of PDB strings."} + if ( + isinstance(api_response_json_payload, dict) + and "error" in api_response_json_payload + ): + logger.error( + f"Job {req_id}: API returned error: " + f"{api_response_json_payload['error']}" + ) + return { + "success": False, + "error": api_response_json_payload["error"], + "detail": api_response_json_payload.get( + "detail", "" + ), + } + logger.error( + f"Job {req_id}: Expected API response to be a list of PDB strings, " + f"got {type(api_response_json_payload)}." + ) + return { + "success": False, + "error": "API response was not a list of PDB strings.", + } return await _process_pdb_and_scores_from_api( pdb_contents=api_response_json_payload, job_id=req_id, - api_response_json=None + api_response_json=None, ) except json.JSONDecodeError: - logger.error(f"Job {req_id}: Failed to decode JSON response from API.", exc_info=True) + logger.error( + f"Job {req_id}: Failed to decode JSON response from API.", + exc_info=True, + ) raw_text = await response.text() - return {"success": False, "error": "Failed to decode JSON response.", "detail": raw_text[:500]} + return { + "success": False, + "error": "Failed to decode JSON response.", + "detail": raw_text[:500], + } else: raw_text = await response.text() - logger.error(f"Job {req_id}: Unexpected content type {response.content_type}. Expected application/json. Response: {raw_text[:500]}") - return {"success": False, "error": f"Unexpected content type: {response.content_type}", "detail": raw_text} + logger.error( + f"Job {req_id}: Unexpected content type {response.content_type}. " + f"Expected application/json. Response: {raw_text[:500]}" + ) + return { + "success": False, + "error": f"Unexpected content type: {response.content_type}", + "detail": raw_text, + } elif response.status == 202: try: job_status_json = await response.json() - percent_complete = job_status_json.get('percentComplete', 'N/A') - status_message = job_status_json.get('status', 'running') + percent_complete = job_status_json.get( + "percentComplete", "N/A" + ) + status_message = job_status_json.get("status", "running") logger.debug( - f"Job {req_id} status: {status_message} ({percent_complete}% complete). Polling again in {polling_interval}s." + f"Job {req_id} status: {status_message} ({percent_complete}% complete). " + f"Polling again in {polling_interval}s." ) except (aiohttp.ContentTypeError, json.JSONDecodeError): logger.debug( - f"Job {req_id} still running (202 status, non-JSON/malformed JSON body). Polling again in {polling_interval}s." + f"Job {req_id} still running (202 status, non-JSON/malformed JSON body). " + f"Polling again in {polling_interval}s." ) await asyncio.sleep(polling_interval) else: text = await response.text() - logger.error(f"Error checking AlphaFold2-Multimer job status {req_id}: HTTP {response.status} - {text}") - return {"success": False, "error": f"Status check failed with HTTP {response.status}", "detail": text} + logger.error( + f"Error checking AlphaFold2-Multimer job status {req_id}: " + f"HTTP {response.status} - {text}" + ) + return { + "success": False, + "error": f"Status check failed with HTTP {response.status}", + "detail": text, + } except asyncio.TimeoutError: - logger.warning(f"Client-side timeout ({current_status_check_timeout}s) during status check for job {req_id}. Retrying poll after {polling_interval}s sleep.") + logger.warning( + f"Client-side timeout ({current_status_check_timeout}s) during status check for " + f"job {req_id}. Retrying poll after {polling_interval}s sleep." + ) await asyncio.sleep(polling_interval) except aiohttp.ClientError as e: - logger.error(f"Client error polling job status for {req_id}: {e}. Retrying poll after {polling_interval}s.", exc_info=True) + logger.error( + f"Client error polling job status for {req_id}: {e}. " + f"Retrying poll after {polling_interval}s.", + exc_info=True, + ) await asyncio.sleep(polling_interval) except Exception as e: - logger.error(f"Unexpected error polling job status {req_id}: {e}", exc_info=True) + logger.error( + f"Unexpected error polling job status {req_id}: {e}", exc_info=True + ) return {"success": False, "error": f"Unexpected polling error: {str(e)}"} diff --git a/environments/hack0/protein_design_env/models/proteinmpnn.py b/environments/community/protein_design/models/proteinmpnn.py similarity index 85% rename from environments/hack0/protein_design_env/models/proteinmpnn.py rename to environments/community/protein_design/models/proteinmpnn.py index f91b0795..cb92cfcc 100644 --- a/environments/hack0/protein_design_env/models/proteinmpnn.py +++ b/environments/community/protein_design/models/proteinmpnn.py @@ -1,16 +1,15 @@ -import os -import logging -import aiohttp -import json import asyncio -from typing import Dict, List, Any, Optional, Union -from pathlib import Path +import logging +from typing import Any, Dict, List, Optional + +import aiohttp logger = logging.getLogger(__name__) DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/ipd/proteinmpnn/predict" DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status" + async def call_proteinmpnn( input_pdb: str, api_key: str, @@ -20,7 +19,7 @@ async def call_proteinmpnn( url: str = DEFAULT_URL, status_url: str = DEFAULT_STATUS_URL, polling_interval: int = 10, - timeout: int = 60 + timeout: int = 60, ) -> Optional[Dict[str, Any]]: """ Call the NVIDIA NIM ProteinMPNN API. @@ -49,16 +48,13 @@ async def call_proteinmpnn( "input_pdb": input_pdb, "ca_only": ca_only, "use_soluble_model": use_soluble_model, - "sampling_temp": sampling_temp + "sampling_temp": sampling_temp, } try: async with aiohttp.ClientSession() as session: async with session.post( - url, - json=data, - headers=headers, - timeout=timeout + url, json=data, headers=headers, timeout=timeout ) as response: if response.status == 200: return await response.json() @@ -71,7 +67,7 @@ async def call_proteinmpnn( headers=headers, status_url=status_url, polling_interval=polling_interval, - timeout=timeout + timeout=timeout, ) else: logger.error("No request ID in response headers") @@ -85,12 +81,13 @@ async def call_proteinmpnn( logger.error(f"Error calling ProteinMPNN API: {e}") return None + async def _poll_job_status( req_id: str, headers: Dict[str, str], status_url: str, polling_interval: int = 10, - timeout: int = 60 + timeout: int = 60, ) -> Optional[Dict[str, Any]]: """ Poll the status endpoint until the job completes. @@ -109,18 +106,20 @@ async def _poll_job_status( try: async with aiohttp.ClientSession() as session: async with session.get( - f"{status_url}/{req_id}", - headers=headers, - timeout=timeout + f"{status_url}/{req_id}", headers=headers, timeout=timeout ) as response: if response.status == 200: logger.info(f"ProteinMPNN job {req_id} completed") return await response.json() elif response.status == 202: - logger.debug(f"ProteinMPNN job {req_id} still running, polling...") + logger.debug( + f"ProteinMPNN job {req_id} still running, polling..." + ) await asyncio.sleep(polling_interval) else: - logger.error(f"Error checking ProteinMPNN job status: {response.status}") + logger.error( + f"Error checking ProteinMPNN job status: {response.status}" + ) text = await response.text() logger.error(f"Response: {text}") return None diff --git a/environments/hack0/protein_design_env/models/rfdiffusion.py b/environments/community/protein_design/models/rfdiffusion.py similarity index 84% rename from environments/hack0/protein_design_env/models/rfdiffusion.py rename to environments/community/protein_design/models/rfdiffusion.py index a88fd6d5..809c632e 100644 --- a/environments/hack0/protein_design_env/models/rfdiffusion.py +++ b/environments/community/protein_design/models/rfdiffusion.py @@ -1,16 +1,15 @@ -import os -import logging -import aiohttp -import json import asyncio -from typing import Dict, List, Any, Optional, Union -from pathlib import Path +import logging +from typing import Any, Dict, List, Optional + +import aiohttp logger = logging.getLogger(__name__) DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/ipd/rfdiffusion/generate" DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status" + async def call_rfdiffusion( input_pdb: str, api_key: str, @@ -20,7 +19,7 @@ async def call_rfdiffusion( url: str = DEFAULT_URL, status_url: str = DEFAULT_STATUS_URL, polling_interval: int = 10, - timeout: int = 60 + timeout: int = 60, ) -> Optional[Dict[str, Any]]: """ Call the NVIDIA NIM RFDiffusion API. @@ -45,10 +44,7 @@ async def call_rfdiffusion( "NVCF-POLL-SECONDS": "300", } - data = { - "input_pdb": input_pdb, - "diffusion_steps": diffusion_steps - } + data = {"input_pdb": input_pdb, "diffusion_steps": diffusion_steps} if contigs: data["contigs"] = contigs @@ -58,10 +54,7 @@ async def call_rfdiffusion( try: async with aiohttp.ClientSession() as session: async with session.post( - url, - json=data, - headers=headers, - timeout=timeout + url, json=data, headers=headers, timeout=timeout ) as response: if response.status == 200: return await response.json() @@ -74,7 +67,7 @@ async def call_rfdiffusion( headers=headers, status_url=status_url, polling_interval=polling_interval, - timeout=timeout + timeout=timeout, ) else: logger.error("No request ID in response headers") @@ -88,12 +81,13 @@ async def call_rfdiffusion( logger.error(f"Error calling RFDiffusion API: {e}") return None + async def _poll_job_status( req_id: str, headers: Dict[str, str], status_url: str, polling_interval: int = 10, - timeout: int = 60 + timeout: int = 60, ) -> Optional[Dict[str, Any]]: """ Poll the status endpoint until the job completes. @@ -112,18 +106,20 @@ async def _poll_job_status( try: async with aiohttp.ClientSession() as session: async with session.get( - f"{status_url}/{req_id}", - headers=headers, - timeout=timeout + f"{status_url}/{req_id}", headers=headers, timeout=timeout ) as response: if response.status == 200: logger.info(f"RFDiffusion job {req_id} completed") return await response.json() elif response.status == 202: - logger.debug(f"RFDiffusion job {req_id} still running, polling...") + logger.debug( + f"RFDiffusion job {req_id} still running, polling..." + ) await asyncio.sleep(polling_interval) else: - logger.error(f"Error checking RFDiffusion job status: {response.status}") + logger.error( + f"Error checking RFDiffusion job status: {response.status}" + ) text = await response.text() logger.error(f"Response: {text}") return None diff --git a/environments/hack0/protein_design_env/prompts.py b/environments/community/protein_design/prompts.py similarity index 56% rename from environments/hack0/protein_design_env/prompts.py rename to environments/community/protein_design/prompts.py index 8378e890..04b5c7f7 100644 --- a/environments/hack0/protein_design_env/prompts.py +++ b/environments/community/protein_design/prompts.py @@ -1,25 +1,25 @@ import logging -from typing import Dict logger = logging.getLogger(__name__) -SYSTEM_PROMPT = """You are a specialized AI system for de novo protein design via a staged simulation loop. Your objective is to generate binder sequences that are structurally and functionally optimized to bind a given target protein. +SYSTEM_PROMPT = ( + "You are a specialized AI system for de novo protein design via a staged simulation loop. " + "Your objective is to generate binder sequences that are structurally and functionally " + "optimized to bind a given target protein.\n\n" + "You will be guided through a multi-step pipeline:\n\n" + "1. Structure prediction (AlphaFold)\n" + "2. Binder backbone generation (RFdiffusion)\n" + "3. Sequence design (ProteinMPNN)\n" + "4. Structure evaluation (AlphaFold-Multimer)\n" + "5. Feedback loop\n\n" + "You must always:\n" + "- Respect the required file format for each tool (e.g., FASTA, PDB).\n" + "- Structure your outputs cleanly so they can be parsed and executed programmatically.\n" + "- Be explicit in all configuration steps (e.g., contigs, hotspots).\n" + "- Minimize ambiguity or verbosity; prefer concise and functional outputs.\n" + "- Reason step-by-step when appropriate." +) -You will be guided through a multi-step pipeline: - -1. Structure prediction (AlphaFold) -2. Binder backbone generation (RFdiffusion) -3. Sequence design (ProteinMPNN) -4. Structure evaluation (AlphaFold-Multimer) -5. Feedback loop - -You must always: -- Respect the required file format for each tool (e.g., FASTA, PDB). -- Structure your outputs cleanly so they can be parsed and executed programmatically. -- Be explicit in all configuration steps (e.g., contigs, hotspots). -- Minimize ambiguity or verbosity; prefer concise and functional outputs. -- Reason step-by-step when appropriate. -""" def construct_user_prompt(state: dict) -> str: """ @@ -42,19 +42,20 @@ def construct_user_prompt(state: dict) -> str: "You must provide the 'sequence' argument." ) elif internal_step == 1: - target_pdb_preview = state.get("target_pdb_preview", "PDB preview not available.") chain_details = state.get("target_chain_details", {}) if chain_details: chain_info_parts = [] for chain_id, details in chain_details.items(): - min_r = details.get('min_residue', 'N/A') - max_r = details.get('max_residue', 'N/A') - l = details.get('length', 'N/A') - chain_info_parts.append(f"Chain {chain_id} (Residues: {min_r}-{max_r}, Length: {l} amino acids)") + min_r = details.get("min_residue", "N/A") + max_r = details.get("max_residue", "N/A") + length = details.get("length", "N/A") + chain_info_parts.append( + f"Chain {chain_id} (Residues: {min_r}-{max_r}, Length: {length} amino acids)" + ) chain_info_str = "\n- ".join(chain_info_parts) if chain_info_str: - chain_info_str = "- " + chain_info_str + chain_info_str = "- " + chain_info_str else: chain_info_str = "Chain information not available or PDB not yet processed." @@ -62,19 +63,26 @@ def construct_user_prompt(state: dict) -> str: f"The 3D structure of the target protein has been predicted.\n" f"Target Protein Chain Details:\n{chain_info_str}\n\n" "Your task is to design a binder backbone using the 'design_binder_backbone_rfdiffusion' tool. " - "You MUST specify 'contigs' for this tool. The 'contigs' string defines segments from the target PDB and segments for the new binder. " + "You MUST specify 'contigs' for this tool. The 'contigs' string defines segments from the target PDB " + "and segments for the new binder. " "Examples:\n" - " - To use residues 10 through 100 of target chain A, and then diffuse a 60-residue binder: 'A10-100/0 60'\n" - " - To use chain B from residue 5 to 50, then diffuse a 30-residue binder, then use chain B from residue 60 to 100: 'B5-50/0 30 B60-100'\n" - "You MUST use the chain IDs and residue ranges exactly as provided in the 'Target Protein Chain Details' above. " + " - To use residues 10 through 100 of target chain A, and then diffuse a 60-residue binder: " + "'A10-100/0 60'\n" + " - To use chain B from residue 5 to 50, then diffuse a 30-residue binder, then use chain B " + "from residue 60 to 100: 'B5-50/0 30 B60-100'\n" + "You MUST use the chain IDs and residue ranges exactly as provided in the " + "'Target Protein Chain Details' above. " "Do not invent chains or residue numbers outside these specified ranges for the target segments. " "For binder segments (e.g., '/0 60'), specify the desired length (e.g., 60).\n" - "Optionally, provide 'hotspot_residues' (e.g., ['A50', 'A52']), ensuring these residues exist on the target as per the details above." + "Optionally, provide 'hotspot_residues' (e.g., ['A50', 'A52']), ensuring these residues exist " + "on the target as per the details above." ) elif internal_step == 2: binder_pdb_content = state.get("binder_backbone_pdb_content") - binder_pdb_preview = state.get("binder_pdb_preview", "Binder PDB preview not available.") + binder_pdb_preview = state.get( + "binder_pdb_preview", "Binder PDB preview not available." + ) binder_chain_info_str = "Binder chain information not available." if binder_pdb_content: @@ -83,10 +91,12 @@ def construct_user_prompt(state: dict) -> str: if binder_chain_details: chain_info_parts = [] for cID, d_details in binder_chain_details.items(): - min_r = d_details.get('min_residue', 'N/A') - max_r = d_details.get('max_residue', 'N/A') - l = d_details.get('length', 'N/A') - chain_info_parts.append(f"Chain {cID} (Residues: {min_r}-{max_r}, Length: {l} amino acids)") + min_r = d_details.get("min_residue", "N/A") + max_r = d_details.get("max_residue", "N/A") + length = d_details.get("length", "N/A") + chain_info_parts.append( + f"Chain {cID} (Residues: {min_r}-{max_r}, Length: {length} amino acids)" + ) binder_chain_info_str = "\n- ".join(chain_info_parts) if binder_chain_info_str: binder_chain_info_str = "- " + binder_chain_info_str @@ -99,7 +109,8 @@ def construct_user_prompt(state: dict) -> str: user_prompt_str = ( f"A binder backbone has been generated. Binder PDB preview:\n{binder_pdb_preview}\n" f"Binder chain information:\n{binder_chain_info_str}.\n" - "Now, design an optimal amino acid sequence for this binder backbone using the 'design_binder_sequence_proteinmpnn' tool. " + "Now, design an optimal amino acid sequence for this binder backbone using the " + "'design_binder_sequence_proteinmpnn' tool. " "You can optionally specify 'sampling_temp' (e.g., [0.1, 0.2])." ) elif internal_step == 3: @@ -110,27 +121,39 @@ def construct_user_prompt(state: dict) -> str: if len(designed_binder_seq_data) == 1: binder_display_str = designed_binder_seq_data[0] else: - binder_display_str = f"{len(designed_binder_seq_data)} chains: " + \ - ", ".join([f"Chain {i+1} ({len(s)} aa): {s[:20]}..." - for i, s in enumerate(designed_binder_seq_data)]) + binder_display_str = ( + f"{len(designed_binder_seq_data)} chains: " + + ", ".join( + [ + f"Chain {i+1} ({len(s)} aa): {s[:20]}..." + for i, s in enumerate(designed_binder_seq_data) + ] + ) + ) elif isinstance(designed_binder_seq_data, str): - binder_display_str = designed_binder_seq_data + binder_display_str = designed_binder_seq_data user_prompt_str = ( f"A binder has been designed. Designed binder sequence(s): {binder_display_str}. " f"The original target sequence was: {target_sequence[:60]}...\n" - "Finally, evaluate the binding complex of the original target protein and ALL chains of this designed binder using the " - "'evaluate_binder_complex_alphafold2_multimer' tool. " + "Finally, evaluate the binding complex of the original target protein and ALL chains of this " + "designed binder using the 'evaluate_binder_complex_alphafold2_multimer' tool. " "You can optionally specify 'relax_prediction' (default is True)." ) else: - user_prompt_str = "The protein design workflow is complete. No further actions required by you for this item. If successful, the key metric was the pLDDT of the complex." + user_prompt_str = ( + "The protein design workflow is complete. No further actions required by you for this item. " + "If successful, the key metric was the pLDDT of the complex." + ) if state.get("retry_count_this_internal_step", 0) > 0 and internal_step < 4: retry_prefix = "Your previous attempt at this step was not successful. " if state.get("previous_tool_error_message"): retry_prefix += f"Details: {state['previous_tool_error_message']}. " - retry_prefix += "Please review the requirements and PDB details carefully and try again to correctly use the expected tool.\n\n" + retry_prefix += ( + "Please review the requirements and PDB details carefully and try again to correctly use " + "the expected tool.\n\n" + ) user_prompt_str = retry_prefix + user_prompt_str return user_prompt_str diff --git a/environments/hack0/protein_design_env/protein_env.py b/environments/community/protein_design/protein_env.py similarity index 52% rename from environments/hack0/protein_design_env/protein_env.py rename to environments/community/protein_design/protein_env.py index 9dbd9eec..2bacd28a 100644 --- a/environments/hack0/protein_design_env/protein_env.py +++ b/environments/community/protein_design/protein_env.py @@ -1,32 +1,37 @@ -import asyncio import json import logging import os -import random -import re import uuid from pathlib import Path -from typing import Dict, List, Any, Tuple, Optional, Union, TypedDict, Set +from typing import Any, Dict, List, Optional, Tuple, TypedDict import yaml -import wandb +from datasets import Dataset, load_dataset from dotenv import load_dotenv -from datasets import load_dataset, Dataset from pydantic import Field -from atroposlib.envs.base import BaseEnv, BaseEnvConfig, Item, APIServerConfig, ScoredDataGroup +import wandb +from atroposlib.envs.base import ( + APIServerConfig, + BaseEnv, + BaseEnvConfig, + Item, + ScoredDataGroup, +) from atroposlib.type_definitions import Message from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer -from environments.hack0.protein_design_env.tool_definitions import ALL_TOOLS_LIST -from environments.hack0.protein_design_env.tool_executor import ToolExecutor -from environments.hack0.protein_design_env.utils.pdb_utils import get_pdb_chain_details -from environments.hack0.protein_design_env.prompts import SYSTEM_PROMPT, construct_user_prompt +from .prompts import SYSTEM_PROMPT, construct_user_prompt +from .tool_definitions import ALL_TOOLS_LIST +from .tool_executor import ToolExecutor logger = logging.getLogger(__name__) load_dotenv() -def load_target_binder_pairs(dataset_name: str, target_col: str, binder_col: str, split: str = "train") -> Dataset: + +def load_target_binder_pairs( + dataset_name: str, target_col: str, binder_col: str, split: str = "train" +) -> Dataset: """ Loads and transforms a Hugging Face dataset to contain only 'target' and 'binder' columns. @@ -46,41 +51,80 @@ def load_target_binder_pairs(dataset_name: str, target_col: str, binder_col: str actual_binder_col = "peptide" try: - ds = ds.rename_columns({actual_target_col: "target", actual_binder_col: "binder"}) - ds = ds.remove_columns([col for col in ds.column_names if col not in {"target", "binder"}]) + ds = ds.rename_columns( + {actual_target_col: "target", actual_binder_col: "binder"} + ) + ds = ds.remove_columns( + [col for col in ds.column_names if col not in {"target", "binder"}] + ) except ValueError as e: logger.error(f"Error renaming columns: {e}") logger.error(f"Available columns: {ds.column_names}") - if actual_target_col in ds.column_names and actual_binder_col in ds.column_names: + if ( + actual_target_col in ds.column_names + and actual_binder_col in ds.column_names + ): ds = ds.select_columns([actual_target_col, actual_binder_col]) - ds = ds.rename_columns({actual_target_col: "target", actual_binder_col: "binder"}) + ds = ds.rename_columns( + {actual_target_col: "target", actual_binder_col: "binder"} + ) else: - logger.error(f"Could not find expected columns in dataset. Available columns: {ds.column_names}") - raise ValueError(f"Dataset {dataset_name} doesn't have the expected columns. Please check your dataset configuration.") + logger.error( + f"Could not find expected columns in dataset. Available columns: {ds.column_names}" + ) + raise ValueError( + f"Dataset {dataset_name} doesn't have the expected columns. Please check your dataset configuration." + ) return ds + class BinderRow(TypedDict): target: str binder: str + class BinderBenchConfig(BaseEnvConfig): nim_api_key: Optional[str] = Field(None, description="NVIDIA NIM API key") - nim_api_base_url: str = Field("https://health.api.nvidia.com/v1", description="NIM API base URL") + nim_api_base_url: str = Field( + "https://health.api.nvidia.com/v1", description="NIM API base URL" + ) api_timeout: int = Field(1800, description="Timeout for NIM API calls") polling_interval: int = Field(30, description="Polling interval for NIM jobs") - output_dir: str = Field(default=str(Path(__file__).parent / "outputs"), description="Directory to save PDBs, etc.") - debug_protein_design_calls: bool = Field(False, description="Enable debug mode for NIM protein API calls, returning mock data.") - max_retries_per_internal_step: int = Field(100, description="Max retries for a failed tool call within a workflow step (0 means no retries).") - dataset_name: str = Field("ronig/protein_binding_sequences", description="Dataset for target sequences") - target_col: str = Field("receptor", description="Target column name (actual column in the dataset)") - binder_col: str = Field("peptide", description="Binder column name (actual column in the dataset)") + output_dir: str = Field( + default=str(Path(__file__).parent / "outputs"), + description="Directory to save PDBs, etc.", + ) + debug_protein_design_calls: bool = Field( + False, + description="Enable debug mode for NIM protein API calls, returning mock data.", + ) + max_retries_per_internal_step: int = Field( + 100, + description="Max retries for a failed tool call within a workflow step (0 means no retries).", + ) + dataset_name: str = Field( + "ronig/protein_binding_sequences", description="Dataset for target sequences" + ) + target_col: str = Field( + "receptor", description="Target column name (actual column in the dataset)" + ) + binder_col: str = Field( + "peptide", description="Binder column name (actual column in the dataset)" + ) + class BinderBenchEnv(BaseEnv): name = "binderbench" env_config_cls = BinderBenchConfig - def __init__(self, config: BinderBenchConfig, server_configs: List[APIServerConfig], slurm=False, testing=False): + def __init__( + self, + config: BinderBenchConfig, + server_configs: List[APIServerConfig], + slurm=False, + testing=False, + ): super().__init__(config, server_configs, slurm, testing) self.config: BinderBenchConfig self.process_mode = False @@ -95,27 +139,35 @@ class BinderBenchEnv(BaseEnv): api_timeout=self.config.api_timeout, polling_interval=self.config.polling_interval, output_dir=self.output_dir, - debug_protein_design_calls=self.config.debug_protein_design_calls + debug_protein_design_calls=self.config.debug_protein_design_calls, ) - async def _execute_tool(self, tool_name: str, args: Dict, workflow_state: Dict) -> Dict: + async def _execute_tool( + self, tool_name: str, args: Dict, workflow_state: Dict + ) -> Dict: """Delegates tool execution and then updates workflow_state based on the result.""" - execution_result_package = await self.tool_executor.dispatch_tool_call(tool_name, args, workflow_state) + execution_result_package = await self.tool_executor.dispatch_tool_call( + tool_name, args, workflow_state + ) tool_output = execution_result_package.get("tool_output", {}) state_updates = execution_result_package.get("state_updates", {}) if state_updates: workflow_state.update(state_updates) - logger.debug(f"Workflow {workflow_state['item_id']}: State updated with keys: {list(state_updates.keys())}") + logger.debug( + f"Workflow {workflow_state['item_id']}: State updated with keys: {list(state_updates.keys())}" + ) return tool_output @classmethod def config_init(cls) -> Tuple[BinderBenchConfig, List[APIServerConfig]]: - default_yaml_path = Path(__file__).parent / "configs" / "binderbench_default.yaml" + default_yaml_path = ( + Path(__file__).parent / "configs" / "binderbench_default.yaml" + ) yaml_config_values = {} if default_yaml_path.exists(): - with open(default_yaml_path, 'r') as f: + with open(default_yaml_path, "r") as f: yaml_config_values = yaml.safe_load(f) or {} env_config = BinderBenchConfig( @@ -124,7 +176,7 @@ class BinderBenchEnv(BaseEnv): nim_api_key=os.environ.get("NVIDIA_NIM_API_KEY"), debug_protein_design_calls=yaml_config_values.get( "debug_protein_design_calls", - bool(os.environ.get("DEBUG_PROTEIN_DESIGN_CALLS", False)) + bool(os.environ.get("DEBUG_PROTEIN_DESIGN_CALLS", False)), ), ) @@ -135,7 +187,7 @@ class BinderBenchEnv(BaseEnv): APIServerConfig( model_name=os.environ.get("DEFAULT_LLM_MODEL", "gpt-4-turbo"), api_key=llm_api_key, - base_url=llm_base_url + base_url=llm_base_url, ) ] return env_config, server_configs @@ -143,18 +195,22 @@ class BinderBenchEnv(BaseEnv): async def setup(self): self.iter = 0 self.train = load_target_binder_pairs( - dataset_name=self.config.dataset_name, - target_col=self.config.target_col, - binder_col=self.config.binder_col - ) + dataset_name=self.config.dataset_name, + target_col=self.config.target_col, + binder_col=self.config.binder_col, + ) logger.info(f"Loaded {len(self.train)} target-binder pairs for {self.name}.") if not self.config.nim_api_key: self.config.nim_api_key = os.environ.get("NVIDIA_NIM_API_KEY") if not self.config.nim_api_key: - logger.warning("NVIDIA NIM API key not set. Protein design functions may not work properly.") + logger.warning( + "NVIDIA NIM API key not set. Protein design functions may not work properly." + ) - def _initialize_workflow_state(self, item_id: str, target_sequence: str, ground_truth_binder: Optional[str]) -> Dict: + def _initialize_workflow_state( + self, item_id: str, target_sequence: str, ground_truth_binder: Optional[str] + ) -> Dict: """Initializes or resets the state for a new workflow.""" return { "item_id": item_id, @@ -191,7 +247,9 @@ class BinderBenchEnv(BaseEnv): target_sequence = raw_item["target"] ground_truth_binder = raw_item.get("binder") - self.episodes_state[item_id] = self._initialize_workflow_state(item_id, target_sequence, ground_truth_binder) + self.episodes_state[item_id] = self._initialize_workflow_state( + item_id, target_sequence, ground_truth_binder + ) return item_id @@ -200,10 +258,14 @@ class BinderBenchEnv(BaseEnv): if item_id in self.episodes_state: return self.episodes_state[item_id] else: - logger.error(f"No state found for item_id {item_id}. Creating a default state.") + logger.error( + f"No state found for item_id {item_id}. Creating a default state." + ) return self._initialize_workflow_state(item_id, "", None) - async def collect_trajectories(self, item_id: str) -> Tuple[Optional[ScoredDataGroup], List[Item]]: + async def collect_trajectories( + self, item_id: str + ) -> Tuple[Optional[ScoredDataGroup], List[Item]]: workflow_state = self.episodes_state.get(item_id) if not workflow_state: logger.error(f"Workflow state for item_id {item_id} not found. Skipping.") @@ -213,69 +275,133 @@ class BinderBenchEnv(BaseEnv): logger.info(f"Workflow for {item_id} already marked complete. Skipping.") return None, [] - is_processing_mode = getattr(self, 'process_mode', False) # Check the flag + is_processing_mode = getattr(self, "process_mode", False) # Check the flag if is_processing_mode: all_turns_data_for_jsonl = [] MAX_INTERNAL_STEPS = 4 - while workflow_state["current_internal_step"] < MAX_INTERNAL_STEPS and \ - not workflow_state.get("workflow_complete_flag"): + while workflow_state[ + "current_internal_step" + ] < MAX_INTERNAL_STEPS and not workflow_state.get("workflow_complete_flag"): current_turn_messages: List[Message] = [] user_prompt_str = construct_user_prompt(workflow_state) - current_turn_messages.append(Message(role="system", content=SYSTEM_PROMPT)) - current_turn_messages.append(Message(role="user", content=user_prompt_str)) + current_turn_messages.append( + Message(role="system", content=SYSTEM_PROMPT) + ) + current_turn_messages.append( + Message(role="user", content=user_prompt_str) + ) llm_response = await self.server.chat_completion( - messages=current_turn_messages, tools=self.tools, tool_choice="auto", n=1, - max_tokens=self.config.max_token_length, temperature=0.5 + messages=current_turn_messages, + tools=self.tools, + tool_choice="auto", + n=1, + max_tokens=self.config.max_token_length, + temperature=0.5, ) assistant_message_obj = llm_response.choices[0].message assistant_content = assistant_message_obj.content or "" assistant_tool_calls = [] - if hasattr(assistant_message_obj, 'tool_calls') and assistant_message_obj.tool_calls: + if ( + hasattr(assistant_message_obj, "tool_calls") + and assistant_message_obj.tool_calls + ): assistant_tool_calls = [ - {"id": tc.id, "type": tc.type, "function": {"name": tc.function.name, "arguments": tc.function.arguments}} + { + "id": tc.id, + "type": tc.type, + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } for tc in assistant_message_obj.tool_calls ] - current_turn_messages.append(Message(role="assistant", content=assistant_content, tool_calls=assistant_tool_calls if assistant_tool_calls else None)) + current_turn_messages.append( + Message( + role="assistant", + content=assistant_content, + tool_calls=( + assistant_tool_calls if assistant_tool_calls else None + ), + ) + ) tool_error_for_retry_prompt = None if assistant_tool_calls: tool_call_request = assistant_tool_calls[0] tool_name = tool_call_request["function"]["name"] try: - tool_args = json.loads(tool_call_request["function"]["arguments"]) - tool_result = await self._execute_tool(tool_name, tool_args, workflow_state) - current_turn_messages.append(Message(role="tool", tool_call_id=tool_call_request["id"] , name=tool_name, content=json.dumps(tool_result))) - workflow_state["last_tool_success"] = tool_result.get("success", False) + tool_args = json.loads( + tool_call_request["function"]["arguments"] + ) + tool_result = await self._execute_tool( + tool_name, tool_args, workflow_state + ) + current_turn_messages.append( + Message( + role="tool", + tool_call_id=tool_call_request["id"], + name=tool_name, + content=json.dumps(tool_result), + ) + ) + workflow_state["last_tool_success"] = tool_result.get( + "success", False + ) if not workflow_state["last_tool_success"]: - tool_error_for_retry_prompt = tool_result.get("error", "Tool execution failed.") + tool_error_for_retry_prompt = tool_result.get( + "error", "Tool execution failed." + ) except Exception as e: error_msg = f"Error processing tool {tool_name}: {str(e)}" - current_turn_messages.append(Message(role="tool", tool_call_id=tool_call_request["id"], name=tool_name, content=error_msg)) + current_turn_messages.append( + Message( + role="tool", + tool_call_id=tool_call_request["id"], + name=tool_name, + content=error_msg, + ) + ) workflow_state["last_tool_success"] = False tool_error_for_retry_prompt = error_msg else: workflow_state["last_tool_success"] = False - expected_tool_name = {0:"AF2",1:"RFD",2:"PMPNN",3:"AF2M"}.get(workflow_state["current_internal_step"], "a tool") - tool_error_for_retry_prompt = f"No tool was called, but {expected_tool_name} was expected." + expected_tool_name = { + 0: "AF2", + 1: "RFD", + 2: "PMPNN", + 3: "AF2M", + }.get(workflow_state["current_internal_step"], "a tool") + tool_error_for_retry_prompt = ( + f"No tool was called, but {expected_tool_name} was expected." + ) - workflow_state["previous_tool_error_message"] = tool_error_for_retry_prompt + workflow_state["previous_tool_error_message"] = ( + tool_error_for_retry_prompt + ) - turn_score_details = self._score_trajectory(current_turn_messages, workflow_state) + turn_score_details = self._score_trajectory( + current_turn_messages, workflow_state + ) current_turn_reward = turn_score_details.get("overall_reward", 0.0) workflow_state["cumulative_reward"] += current_turn_reward - tokenization_result = tokenize_for_trainer(self.tokenizer, current_turn_messages, include_messages=False) - all_turns_data_for_jsonl.append({ - "tokens_this_turn": tokenization_result["tokens"], - "masks_this_turn": tokenization_result["masks"], - "score_this_turn": current_turn_reward, - "messages_this_turn": current_turn_messages.copy(), - "overrides_this_turn": turn_score_details.copy() - }) + tokenization_result = tokenize_for_trainer( + self.tokenizer, current_turn_messages, include_messages=False + ) + all_turns_data_for_jsonl.append( + { + "tokens_this_turn": tokenization_result["tokens"], + "masks_this_turn": tokenization_result["masks"], + "score_this_turn": current_turn_reward, + "messages_this_turn": current_turn_messages.copy(), + "overrides_this_turn": turn_score_details.copy(), + } + ) if workflow_state["last_tool_success"]: workflow_state["current_internal_step"] += 1 @@ -284,24 +410,42 @@ class BinderBenchEnv(BaseEnv): else: if workflow_state["current_internal_step"] <= 3: workflow_state["retry_count_this_internal_step"] += 1 - if workflow_state["retry_count_this_internal_step"] > self.config.max_retries_per_internal_step: - logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: Max retries ({self.config.max_retries_per_internal_step}) reached. Terminating workflow for this item.") + if ( + workflow_state["retry_count_this_internal_step"] + > self.config.max_retries_per_internal_step + ): + logger.warning( + f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: " + f"Max retries ({self.config.max_retries_per_internal_step}) reached. " + f"Terminating workflow for this item." + ) workflow_state["workflow_complete_flag"] = True break else: - logger.info(f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: Failed, attempt {workflow_state['retry_count_this_internal_step']}. Retrying same step.") + logger.info( + f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: " + f"Failed, attempt {workflow_state['retry_count_this_internal_step']}. " + f"Retrying same step." + ) else: - logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: Failure at non-retryable step. Terminating workflow.") + logger.warning( + f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: " + f"Failure at non-retryable step. Terminating workflow." + ) workflow_state["workflow_complete_flag"] = True break if workflow_state["current_internal_step"] >= MAX_INTERNAL_STEPS: workflow_state["workflow_complete_flag"] = True - logger.info(f"Workflow {item_id}: All internal steps completed successfully.") + logger.info( + f"Workflow {item_id}: All internal steps completed successfully." + ) if not all_turns_data_for_jsonl: - logger.warning(f"Workflow {item_id} in process mode: No turn data collected.") + logger.warning( + f"Workflow {item_id} in process mode: No turn data collected." + ) return None, [] html_compatible_messages: List[str] = [] @@ -315,30 +459,55 @@ class BinderBenchEnv(BaseEnv): content_str = str(msg_obj.get("content", "[No Content]")) if msg_obj.get("tool_calls"): try: - tool_calls_str = json.dumps(msg_obj.get("tool_calls"), indent=2) + tool_calls_str = json.dumps( + msg_obj.get("tool_calls"), indent=2 + ) content_str += f"\nTool Calls:\n{tool_calls_str}" - except TypeError: # Handle non-serializable content if any - content_str += f"\nTool Calls: [Error serializing tool_calls]" - turn_str_parts.append(f"**{msg_obj.get('role', 'unknown').upper()}**: {content_str}") + except TypeError: # Handle non-serializable content if any + content_str += ( + "\nTool Calls: [Error serializing tool_calls]" + ) + turn_str_parts.append( + f"**{msg_obj.get('role', 'unknown').upper()}**: {content_str}" + ) else: turn_str_parts.append("No messages recorded for this turn.") html_compatible_messages.append("\n\n".join(turn_str_parts)) - turn_score = turn_data.get("overrides_this_turn", {}).get("overall_reward", 0.0) + turn_score = turn_data.get("overrides_this_turn", {}).get( + "overall_reward", 0.0 + ) html_compatible_scores.append(turn_score) overrides_for_jsonl.append(turn_data.get("overrides_this_turn", {})) final_workflow_reward = workflow_state.get("cumulative_reward", 0.0) - if workflow_state.get("complex_evaluated") and workflow_state.get("last_tool_success"): - final_workflow_reward = all_turns_data_for_jsonl[-1].get("overrides_this_turn", {}).get("overall_reward", 0.0) + if workflow_state.get("complex_evaluated") and workflow_state.get( + "last_tool_success" + ): + final_workflow_reward = ( + all_turns_data_for_jsonl[-1] + .get("overrides_this_turn", {}) + .get("overall_reward", 0.0) + ) - all_tokens_per_turn = [turn_data["tokens_this_turn"] for turn_data in all_turns_data_for_jsonl if turn_data.get("tokens_this_turn")] - all_masks_per_turn = [turn_data["masks_this_turn"] for turn_data in all_turns_data_for_jsonl if turn_data.get("masks_this_turn")] + all_tokens_per_turn = [ + turn_data["tokens_this_turn"] + for turn_data in all_turns_data_for_jsonl + if turn_data.get("tokens_this_turn") + ] + all_masks_per_turn = [ + turn_data["masks_this_turn"] + for turn_data in all_turns_data_for_jsonl + if turn_data.get("masks_this_turn") + ] if len(all_tokens_per_turn) != len(html_compatible_messages): - logger.error(f"CRITICAL: Mismatch between tokenized turns ({len(all_tokens_per_turn)}) and HTML messages ({len(html_compatible_messages)}). JSONL will be problematic.") + logger.error( + f"CRITICAL: Mismatch between tokenized turns ({len(all_tokens_per_turn)}) " + f"and HTML messages ({len(html_compatible_messages)}). JSONL will be problematic." + ) if all_turns_data_for_jsonl and all_tokens_per_turn: last_tokens = all_tokens_per_turn[-1] last_masks = all_masks_per_turn[-1] @@ -360,76 +529,152 @@ class BinderBenchEnv(BaseEnv): "is_process_mode_full_workflow": True, "final_score_for_workflow": final_workflow_reward, "target_sequence": workflow_state.get("target_sequence", "N/A"), - "designed_binder_sequence": workflow_state.get("designed_binder_sequence", "N/A"), - "final_plddt": workflow_state.get("af2_multimer_plddt", 0.0) - } + "designed_binder_sequence": workflow_state.get( + "designed_binder_sequence", "N/A" + ), + "final_plddt": workflow_state.get("af2_multimer_plddt", 0.0), + }, ) await self.add_rollouts_for_wandb(data_for_log=workflow_state.copy()) self.completed_episode_metrics.append(workflow_state.copy()) - if item_id in self.episodes_state: del self.episodes_state[item_id] + if item_id in self.episodes_state: + del self.episodes_state[item_id] return process_mode_scored_data, [] else: current_turn_messages_serve: List[Message] = [] user_prompt_str_serve = construct_user_prompt(workflow_state) - current_turn_messages_serve.append(Message(role="system", content=SYSTEM_PROMPT)) - current_turn_messages_serve.append(Message(role="user", content=user_prompt_str_serve)) + current_turn_messages_serve.append( + Message(role="system", content=SYSTEM_PROMPT) + ) + current_turn_messages_serve.append( + Message(role="user", content=user_prompt_str_serve) + ) llm_response_serve = await self.server.chat_completion( - messages=current_turn_messages_serve, tools=self.tools, tool_choice="auto", n=1, - max_tokens=self.config.max_token_length, temperature=0.5 + messages=current_turn_messages_serve, + tools=self.tools, + tool_choice="auto", + n=1, + max_tokens=self.config.max_token_length, + temperature=0.5, ) assistant_message_obj_serve = llm_response_serve.choices[0].message assistant_content_serve = assistant_message_obj_serve.content or "" assistant_tool_calls_serve = [] - if hasattr(assistant_message_obj_serve, 'tool_calls') and assistant_message_obj_serve.tool_calls: + if ( + hasattr(assistant_message_obj_serve, "tool_calls") + and assistant_message_obj_serve.tool_calls + ): assistant_tool_calls_serve = [ - {"id": tc.id, "type": tc.type, "function": {"name": tc.function.name, "arguments": tc.function.arguments}} + { + "id": tc.id, + "type": tc.type, + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } for tc in assistant_message_obj_serve.tool_calls ] - current_turn_messages_serve.append(Message(role="assistant", content=assistant_content_serve, tool_calls=assistant_tool_calls_serve if assistant_tool_calls_serve else None)) + current_turn_messages_serve.append( + Message( + role="assistant", + content=assistant_content_serve, + tool_calls=( + assistant_tool_calls_serve + if assistant_tool_calls_serve + else None + ), + ) + ) tool_error_for_retry_prompt_serve = None if assistant_tool_calls_serve: tool_call_request_serve = assistant_tool_calls_serve[0] tool_name_serve = tool_call_request_serve["function"]["name"] try: - tool_args_json_str = tool_call_request_serve["function"]["arguments"] + tool_args_json_str = tool_call_request_serve["function"][ + "arguments" + ] tool_args_serve = json.loads(tool_args_json_str) - tool_result_serve = await self._execute_tool(tool_name_serve, tool_args_serve, workflow_state) - current_turn_messages_serve.append(Message(role="tool", tool_call_id=tool_call_request_serve["id"] , name=tool_name_serve, content=json.dumps(tool_result_serve))) - workflow_state["last_tool_success"] = tool_result_serve.get("success", False) + tool_result_serve = await self._execute_tool( + tool_name_serve, tool_args_serve, workflow_state + ) + current_turn_messages_serve.append( + Message( + role="tool", + tool_call_id=tool_call_request_serve["id"], + name=tool_name_serve, + content=json.dumps(tool_result_serve), + ) + ) + workflow_state["last_tool_success"] = tool_result_serve.get( + "success", False + ) if not workflow_state["last_tool_success"]: - tool_error_for_retry_prompt_serve = tool_result_serve.get("error", "Tool execution failed.") + tool_error_for_retry_prompt_serve = tool_result_serve.get( + "error", "Tool execution failed." + ) except Exception as e: - error_msg_serve = f"Error processing tool {tool_name_serve}: {str(e)}" - current_turn_messages_serve.append(Message(role="tool", tool_call_id=tool_call_request_serve["id"], name=tool_name_serve, content=error_msg_serve)) + error_msg_serve = ( + f"Error processing tool {tool_name_serve}: {str(e)}" + ) + current_turn_messages_serve.append( + Message( + role="tool", + tool_call_id=tool_call_request_serve["id"], + name=tool_name_serve, + content=error_msg_serve, + ) + ) workflow_state["last_tool_success"] = False tool_error_for_retry_prompt_serve = error_msg_serve else: workflow_state["last_tool_success"] = False - expected_tool_name_serve = {0:"AF2",1:"RFD",2:"PMPNN",3:"AF2M"}.get(workflow_state["current_internal_step"], "a tool") - tool_error_for_retry_prompt_serve = f"No tool was called, but {expected_tool_name_serve} was expected." + expected_tool_name_serve = { + 0: "AF2", + 1: "RFD", + 2: "PMPNN", + 3: "AF2M", + }.get(workflow_state["current_internal_step"], "a tool") + tool_error_for_retry_prompt_serve = ( + f"No tool was called, but {expected_tool_name_serve} was expected." + ) - workflow_state["previous_tool_error_message"] = tool_error_for_retry_prompt_serve + workflow_state["previous_tool_error_message"] = ( + tool_error_for_retry_prompt_serve + ) - turn_score_details_serve = self._score_trajectory(current_turn_messages_serve, workflow_state) - current_turn_reward_serve = turn_score_details_serve.get("overall_reward", 0.0) + turn_score_details_serve = self._score_trajectory( + current_turn_messages_serve, workflow_state + ) + current_turn_reward_serve = turn_score_details_serve.get( + "overall_reward", 0.0 + ) workflow_state["cumulative_reward"] += current_turn_reward_serve - workflow_state["turn_messages_history"].append(current_turn_messages_serve.copy()) + workflow_state["turn_messages_history"].append( + current_turn_messages_serve.copy() + ) tokenization_result_serve = tokenize_for_trainer( - self.tokenizer, current_turn_messages_serve, include_messages=self.config.include_messages + self.tokenizer, + current_turn_messages_serve, + include_messages=self.config.include_messages, ) scored_data_serve = ScoredDataGroup( tokens=[tokenization_result_serve["tokens"]], masks=[tokenization_result_serve["masks"]], scores=[current_turn_reward_serve], - messages=[current_turn_messages_serve] if self.config.include_messages else None, + messages=( + [current_turn_messages_serve] + if self.config.include_messages + else None + ), overrides=[turn_score_details_serve], - group_overrides={"group_size": 1} + group_overrides={"group_size": 1}, ) backlog_items_serve = [] @@ -440,36 +685,59 @@ class BinderBenchEnv(BaseEnv): else: if workflow_state["current_internal_step"] <= 3: workflow_state["retry_count_this_internal_step"] += 1 - if workflow_state["retry_count_this_internal_step"] > self.config.max_retries_per_internal_step: - logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']} (Serve Mode): Max retries reached. Terminating.") + if ( + workflow_state["retry_count_this_internal_step"] + > self.config.max_retries_per_internal_step + ): + logger.warning( + f"Workflow {item_id}, Step {workflow_state['current_internal_step']} " + f"(Serve Mode): Max retries reached. Terminating." + ) workflow_state["workflow_complete_flag"] = True else: - logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']} (Serve Mode): Failure at non-retryable step. Terminating.") + logger.warning( + f"Workflow {item_id}, Step {workflow_state['current_internal_step']} " + f"(Serve Mode): Failure at non-retryable step. Terminating." + ) workflow_state["workflow_complete_flag"] = True - if workflow_state["current_internal_step"] < 4 and not workflow_state.get("workflow_complete_flag"): + if workflow_state["current_internal_step"] < 4 and not workflow_state.get( + "workflow_complete_flag" + ): should_add_to_backlog = False if workflow_state["last_tool_success"]: should_add_to_backlog = True - elif workflow_state["current_internal_step"] <= 3 and \ - workflow_state["retry_count_this_internal_step"] <= self.config.max_retries_per_internal_step: + elif ( + workflow_state["current_internal_step"] <= 3 + and workflow_state["retry_count_this_internal_step"] + <= self.config.max_retries_per_internal_step + ): should_add_to_backlog = True if should_add_to_backlog: backlog_items_serve.append(item_id) else: workflow_state["workflow_complete_flag"] = True - logger.info(f"Workflow for {item_id} (Serve Mode) not added to backlog and marked complete. Internal step: {workflow_state['current_internal_step']}") + logger.info( + f"Workflow for {item_id} (Serve Mode) not added to backlog and marked complete. " + f"Internal step: {workflow_state['current_internal_step']}" + ) if workflow_state.get("workflow_complete_flag"): if item_id in self.episodes_state: - await self.add_rollouts_for_wandb(data_for_log=self.episodes_state[item_id].copy()) - self.completed_episode_metrics.append(self.episodes_state[item_id].copy()) + await self.add_rollouts_for_wandb( + data_for_log=self.episodes_state[item_id].copy() + ) + self.completed_episode_metrics.append( + self.episodes_state[item_id].copy() + ) del self.episodes_state[item_id] return scored_data_serve, backlog_items_serve - def _score_trajectory(self, turn_messages: List[Message], workflow_state: Dict) -> Dict[str, float]: + def _score_trajectory( + self, turn_messages: List[Message], workflow_state: Dict + ) -> Dict[str, float]: """ Scores a single turn's trajectory based on the specified reward logic. - Steps 0-2: Format reward (0.2 for correct & successful tool call, 0 otherwise). @@ -482,56 +750,93 @@ class BinderBenchEnv(BaseEnv): internal_step = workflow_state.get("current_internal_step") last_tool_success = workflow_state.get("last_tool_success", False) - assistant_msg_dict = next((m for m in reversed(turn_messages) if m.get("role") == "assistant"), None) + assistant_msg_dict = next( + (m for m in reversed(turn_messages) if m.get("role") == "assistant"), None + ) expected_tool_for_step = { 0: "predict_target_structure_alphafold2", 1: "design_binder_backbone_rfdiffusion", 2: "design_binder_sequence_proteinmpnn", - 3: "evaluate_binder_complex_alphafold2_multimer" + 3: "evaluate_binder_complex_alphafold2_multimer", }.get(internal_step) called_tool_name = None if assistant_msg_dict and assistant_msg_dict.get("tool_calls"): tool_calls_list = assistant_msg_dict.get("tool_calls") - if tool_calls_list and isinstance(tool_calls_list, list) and len(tool_calls_list) > 0: + if ( + tool_calls_list + and isinstance(tool_calls_list, list) + and len(tool_calls_list) > 0 + ): function_call_dict = tool_calls_list[0].get("function") if function_call_dict and isinstance(function_call_dict, dict): - called_tool_name = function_call_dict.get("name") + called_tool_name = function_call_dict.get("name") if internal_step < 3: if last_tool_success and called_tool_name == expected_tool_for_step: detailed_scores["overall_reward"] = 0.2 - logger.info(f"Workflow {workflow_state['item_id']}, Step {internal_step}: Correct tool '{called_tool_name}' used successfully. Reward: 0.2") + logger.info( + f"Workflow {workflow_state['item_id']}, Step {internal_step}: " + f"Correct tool '{called_tool_name}' used successfully. Reward: 0.2" + ) else: detailed_scores["overall_reward"] = 0.0 if not last_tool_success and called_tool_name: - logger.warning(f"Workflow {workflow_state['item_id']}, Step {internal_step}: Tool '{called_tool_name}' execution failed. Reward: 0.0") + logger.warning( + f"Workflow {workflow_state['item_id']}, Step {internal_step}: " + f"Tool '{called_tool_name}' execution failed. Reward: 0.0" + ) elif called_tool_name != expected_tool_for_step: - logger.warning(f"Workflow {workflow_state['item_id']}, Step {internal_step}: Incorrect tool '{called_tool_name}' used (expected '{expected_tool_for_step}'). Reward: 0.0") + logger.warning( + f"Workflow {workflow_state['item_id']}, Step {internal_step}: " + f"Incorrect tool '{called_tool_name}' used (expected '{expected_tool_for_step}'). " + f"Reward: 0.0" + ) elif not called_tool_name and expected_tool_for_step: - logger.warning(f"Workflow {workflow_state['item_id']}, Step {internal_step}: No tool called, but expected '{expected_tool_for_step}'. Reward: 0.0") + logger.warning( + f"Workflow {workflow_state['item_id']}, Step {internal_step}: " + f"No tool called, but expected '{expected_tool_for_step}'. Reward: 0.0" + ) elif internal_step == 3: - if workflow_state.get("complex_evaluated") and last_tool_success and called_tool_name == expected_tool_for_step: + if ( + workflow_state.get("complex_evaluated") + and last_tool_success + and called_tool_name == expected_tool_for_step + ): plddt = workflow_state.get("af2_multimer_plddt", 0.0) detailed_scores["raw_plddt"] = plddt if plddt > 90.0: detailed_scores["overall_reward"] = 1.0 elif plddt > 50.0: - detailed_scores["overall_reward"] = 0.0 + (plddt - 50.0) * (1.0 - 0.0) / (90.0 - 50.0) - detailed_scores["overall_reward"] = max(0.0, min(detailed_scores["overall_reward"], 1.0)) + detailed_scores["overall_reward"] = 0.0 + (plddt - 50.0) * ( + 1.0 - 0.0 + ) / (90.0 - 50.0) + detailed_scores["overall_reward"] = max( + 0.0, min(detailed_scores["overall_reward"], 1.0) + ) else: detailed_scores["overall_reward"] = 0.0 - logger.info(f"Workflow {workflow_state['item_id']}, Step {internal_step} (AF2-Multimer): pLDDT={plddt:.2f}. Reward: {detailed_scores['overall_reward']:.2f}") + logger.info( + f"Workflow {workflow_state['item_id']}, Step {internal_step} (AF2-Multimer): " + f"pLDDT={plddt:.2f}. Reward: {detailed_scores['overall_reward']:.2f}" + ) else: detailed_scores["overall_reward"] = 0.0 - logger.warning(f"Workflow {workflow_state['item_id']}, Step {internal_step} (AF2-Multimer): Evaluation failed or wrong tool. Reward: 0.0. Last tool success: {last_tool_success}, Called: {called_tool_name}") + logger.warning( + f"Workflow {workflow_state['item_id']}, Step {internal_step} (AF2-Multimer): " + f"Evaluation failed or wrong tool. Reward: 0.0. Last tool success: {last_tool_success}, " + f"Called: {called_tool_name}" + ) else: - logger.error(f"Workflow {workflow_state['item_id']}: Invalid internal_step {internal_step} in scoring.") + logger.error( + f"Workflow {workflow_state['item_id']}: " + f"Invalid internal_step {internal_step} in scoring." + ) detailed_scores["overall_reward"] = -1.0 return detailed_scores @@ -554,9 +859,11 @@ class BinderBenchEnv(BaseEnv): logger.info(f"Running evaluation for {self.name}...") if not self.completed_episode_metrics: logger.info("No completed episodes to evaluate since last evaluation.") - self.eval_metrics = [] # Ensure eval_metrics is an empty list if no new data + self.eval_metrics = ( + [] + ) # Ensure eval_metrics is an empty list if no new data if self.config.use_wandb: - await self.wandb_log({}) # Log that no eval data was present this cycle + await self.wandb_log({}) # Log that no eval data was present this cycle return plddts, cumulative_rewards, workflow_successes = [], [], [] @@ -569,13 +876,23 @@ class BinderBenchEnv(BaseEnv): workflow_successes.append(0.0) cumulative_rewards.append(ep_state.get("cumulative_reward", 0.0)) - self.eval_metrics = [] # Reset class member for current evaluation results + self.eval_metrics = [] # Reset class member for current evaluation results if plddts: self.eval_metrics.append(("eval/avg_plddt", sum(plddts) / len(plddts))) if cumulative_rewards: - self.eval_metrics.append(("eval/avg_cumulative_reward", sum(cumulative_rewards) / len(cumulative_rewards))) + self.eval_metrics.append( + ( + "eval/avg_cumulative_reward", + sum(cumulative_rewards) / len(cumulative_rewards), + ) + ) if workflow_successes: - self.eval_metrics.append(("eval/workflow_success_rate", sum(workflow_successes) / len(workflow_successes))) + self.eval_metrics.append( + ( + "eval/workflow_success_rate", + sum(workflow_successes) / len(workflow_successes), + ) + ) logger.info(f"Evaluation complete. Calculated metrics: {self.eval_metrics}") @@ -584,10 +901,12 @@ class BinderBenchEnv(BaseEnv): self.completed_episode_metrics.clear() - async def add_rollouts_for_wandb(self, - scored_data_group: ScoredDataGroup = None, - item_id: Item = None, - data_for_log: Dict = None): + async def add_rollouts_for_wandb( + self, + scored_data_group: ScoredDataGroup = None, + item_id: Item = None, + data_for_log: Dict = None, + ): """Adds a workflow summary to the wandb rollout buffer. This method has two modes of operation: @@ -620,7 +939,9 @@ class BinderBenchEnv(BaseEnv): workflow_state = self.episodes_state[item_id] if workflow_state is None: - logger.debug(f"No workflow_state available for WandB logging (item_id={item_id})") + logger.debug( + f"No workflow_state available for WandB logging (item_id={item_id})" + ) return target_seq = workflow_state.get("target_sequence", "N/A") @@ -630,28 +951,46 @@ class BinderBenchEnv(BaseEnv): last_turn_messages_str = "No messages." try: - if workflow_state.get("turn_messages_history") and len(workflow_state["turn_messages_history"]) > 0: + if ( + workflow_state.get("turn_messages_history") + and len(workflow_state["turn_messages_history"]) > 0 + ): last_turn_convo = workflow_state["turn_messages_history"][-1] last_turn_messages_str = "\n---\n".join( - [f"{msg.get('role', 'unknown')}: {str(msg.get('content', ''))[:200]}..." for msg in last_turn_convo] + [ + f"{msg.get('role', 'unknown')}: {str(msg.get('content', ''))[:200]}..." + for msg in last_turn_convo + ] ) except Exception as e: logger.error(f"Error processing messages for WandB: {e}") last_turn_messages_str = "Error processing messages" - target_preview = target_seq[:30] + "..." if isinstance(target_seq, str) and len(target_seq) > 30 else target_seq + target_preview = ( + target_seq[:30] + "..." + if isinstance(target_seq, str) and len(target_seq) > 30 + else target_seq + ) designed_binder_data = workflow_state.get("designed_binder_sequence", "N/A") binder_preview = "N/A" if isinstance(designed_binder_data, list) and designed_binder_data: first_chain_seq = str(designed_binder_data[0]) - preview_text = first_chain_seq[:30] + "..." if len(first_chain_seq) > 30 else first_chain_seq + preview_text = ( + first_chain_seq[:30] + "..." + if len(first_chain_seq) > 30 + else first_chain_seq + ) if len(designed_binder_data) > 1: binder_preview = f"{len(designed_binder_data)} chains: {preview_text}" else: binder_preview = preview_text elif isinstance(designed_binder_data, str) and designed_binder_data != "N/A": - binder_preview = designed_binder_data[:30] + "..." if len(designed_binder_data) > 30 else designed_binder_data + binder_preview = ( + designed_binder_data[:30] + "..." + if len(designed_binder_data) > 30 + else designed_binder_data + ) if item_id is None: item_id = workflow_state.get("item_id", "unknown-id") @@ -663,7 +1002,7 @@ class BinderBenchEnv(BaseEnv): binder_preview, f"{plddt:.2f}", f"{cumulative_reward:.3f}", - last_turn_messages_str + last_turn_messages_str, ) ) if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep: @@ -672,15 +1011,21 @@ class BinderBenchEnv(BaseEnv): async def create_rollout_table(self, wandb_metrics: Dict) -> Dict: """Creates a wandb.Table from the buffered rollouts.""" if hasattr(self, "rollouts_for_wandb") and self.rollouts_for_wandb: - columns = ["Item ID", "Target (Preview)", "Designed Binder (Preview)", - "Final pLDDT", "Cumulative Reward", "Last Turn Messages"] + columns = [ + "Item ID", + "Target (Preview)", + "Designed Binder (Preview)", + "Final pLDDT", + "Cumulative Reward", + "Last Turn Messages", + ] table = wandb.Table(columns=columns) for rollout_tuple in self.rollouts_for_wandb: table.add_data(*rollout_tuple) table_key = f"env_rollouts/{self.wandb_prepend}/completed_workflows" if self.wandb_prepend is None and hasattr(self, "name"): - table_key = f"env_rollouts/{self.name}/completed_workflows" + table_key = f"env_rollouts/{self.name}/completed_workflows" wandb_metrics[table_key] = table self.rollouts_for_wandb.clear() @@ -695,5 +1040,6 @@ class BinderBenchEnv(BaseEnv): await super().wandb_log(wandb_metrics) + if __name__ == "__main__": BinderBenchEnv.cli() diff --git a/environments/community/protein_design/tool_definitions.py b/environments/community/protein_design/tool_definitions.py new file mode 100644 index 00000000..59c42a78 --- /dev/null +++ b/environments/community/protein_design/tool_definitions.py @@ -0,0 +1,92 @@ +PREDICT_TARGET_STRUCTURE_TOOL = { + "type": "function", + "function": { + "name": "predict_target_structure_alphafold2", + "description": "Predicts the 3D structure of the target protein sequence using AlphaFold2.", + "parameters": { + "type": "object", + "properties": { + "sequence": { + "type": "string", + "description": "Amino acid sequence of the target protein.", + }, + }, + "required": ["sequence"], + }, + }, +} + +DESIGN_BINDER_BACKBONE_TOOL = { + "type": "function", + "function": { + "name": "design_binder_backbone_rfdiffusion", + "description": ( + "Generates a novel protein binder backbone using RFDiffusion, " + "conditioned on the target protein structure." + ), + "parameters": { + "type": "object", + "properties": { + "contigs": { + "type": "string", + "description": "RFDiffusion contigs (e.g., 'A1-100/0 50-70').", + }, + "hotspot_residues": { + "type": "array", + "items": {"type": "string"}, + "description": "Optional hotspot residues (e.g., ['A50', 'A52']).", + }, + }, + "required": ["contigs"], + }, + }, +} + +DESIGN_BINDER_SEQUENCE_TOOL = { + "type": "function", + "function": { + "name": "design_binder_sequence_proteinmpnn", + "description": "Designs an amino acid sequence for the generated binder backbone.", + "parameters": { + "type": "object", + "properties": { + "sampling_temp": { + "type": "array", + "items": {"type": "number"}, + "description": ( + "Sampling temperatures (e.g., [0.1, 0.2]). Default [0.1]." + ), + } + }, + "required": [], + }, + }, +} + +EVALUATE_COMPLEX_TOOL = { + "type": "function", + "function": { + "name": "evaluate_binder_complex_alphafold2_multimer", + "description": ( + "Predicts the complex structure of target and designed binder, " + "providing quality scores." + ), + "parameters": { + "type": "object", + "properties": { + "relax_prediction": { + "type": "boolean", + "description": "Whether to relax the prediction. Default True.", + } + }, + "required": [], + }, + }, +} + +ALL_TOOLS_LIST = [ + PREDICT_TARGET_STRUCTURE_TOOL, + DESIGN_BINDER_BACKBONE_TOOL, + DESIGN_BINDER_SEQUENCE_TOOL, + EVALUATE_COMPLEX_TOOL, +] diff --git a/environments/hack0/protein_design_env/tool_executor.py b/environments/community/protein_design/tool_executor.py similarity index 52% rename from environments/hack0/protein_design_env/tool_executor.py rename to environments/community/protein_design/tool_executor.py index cbfcff3d..f65e07dd 100644 --- a/environments/hack0/protein_design_env/tool_executor.py +++ b/environments/community/protein_design/tool_executor.py @@ -1,19 +1,26 @@ import logging -import json import re -from typing import Dict, Any, List, Tuple, Optional, Union from pathlib import Path -from environments.hack0.protein_design_env.models.alphafold2 import call_alphafold2 -from environments.hack0.protein_design_env.models.rfdiffusion import call_rfdiffusion -from environments.hack0.protein_design_env.models.proteinmpnn import call_proteinmpnn -from environments.hack0.protein_design_env.models.alphafold2_multimer import call_alphafold2_multimer -from environments.hack0.protein_design_env.utils.pdb_utils import get_pdb_chain_details +from typing import Dict, List, Optional, Tuple + +from .models.alphafold2 import call_alphafold2 +from .models.alphafold2_multimer import call_alphafold2_multimer +from .models.proteinmpnn import call_proteinmpnn +from .models.rfdiffusion import call_rfdiffusion +from .utils.pdb_utils import get_pdb_chain_details logger = logging.getLogger(__name__) + class ToolExecutor: - def __init__(self, nim_api_key: str, api_timeout: int, polling_interval: int, - output_dir: Path, debug_protein_design_calls: bool): + def __init__( + self, + nim_api_key: str, + api_timeout: int, + polling_interval: int, + output_dir: Path, + debug_protein_design_calls: bool, + ): self.nim_api_key = nim_api_key self.api_timeout = api_timeout self.polling_interval = polling_interval @@ -21,18 +28,23 @@ class ToolExecutor: self.debug_protein_design_calls = debug_protein_design_calls self._debug_af2m_call_count = 0 - def _validate_rfd_contigs(self, contigs_str: str, target_chain_details: Dict[str, Dict[str, int]]) -> Optional[str]: + def _validate_rfd_contigs( + self, contigs_str: str, target_chain_details: Dict[str, Dict[str, int]] + ) -> Optional[str]: """ Validates the RFDiffusion contigs string against target PDB chain details. Returns None if valid, or an error message string if invalid. """ - if not contigs_str: return "Contigs string is empty." + if not contigs_str: + return "Contigs string is empty." target_segment_pattern = re.compile(r"([A-Za-z0-9])(\d+)-(\d+)") - active_contig_parts = contigs_str.split('/') # Split by binder definition markers + active_contig_parts = contigs_str.split( + "/" + ) # Split by binder definition markers for part in active_contig_parts: - chain_segments_in_part = part.strip().split(' ') + chain_segments_in_part = part.strip().split(" ") for segment_text in chain_segments_in_part: segment_text = segment_text.strip() if not segment_text or segment_text.isdigit(): @@ -45,42 +57,61 @@ class ToolExecutor: seg_end = int(seg_end_str) if seg_chain_id not in target_chain_details: - return f"Chain '{seg_chain_id}' in contig segment '{segment_text}' not in target. Valid: {list(target_chain_details.keys())}." + return ( + f"Chain '{seg_chain_id}' in contig segment '{segment_text}' not in target. " + f"Valid: {list(target_chain_details.keys())}." + ) chain_min = target_chain_details[seg_chain_id]["min_residue"] chain_max = target_chain_details[seg_chain_id]["max_residue"] - if not (chain_min <= seg_start <= chain_max and chain_min <= seg_end <= chain_max and seg_start <= seg_end): - return (f"Residue range {seg_start}-{seg_end} for chain '{seg_chain_id}' in '{segment_text}' " - f"is invalid/out of bounds. Chain '{seg_chain_id}' actual range: {chain_min}-{chain_max}.") + if not ( + chain_min <= seg_start <= chain_max + and chain_min <= seg_end <= chain_max + and seg_start <= seg_end + ): + return ( + f"Residue range {seg_start}-{seg_end} for chain '{seg_chain_id}' in '{segment_text}' " + f"is invalid/out of bounds. Chain '{seg_chain_id}' actual range: {chain_min}-{chain_max}." + ) return None - def _validate_rfd_hotspots(self, hotspot_list: List[str], target_chain_details: Dict[str, Dict[str, int]]) -> Optional[str]: + def _validate_rfd_hotspots( + self, hotspot_list: List[str], target_chain_details: Dict[str, Dict[str, int]] + ) -> Optional[str]: """ Validates hotspot residues (e.g., ["A50", "B25"]) against target PDB chain details. Returns None if valid, or an error message string if invalid. """ - if not hotspot_list: return None + if not hotspot_list: + return None hotspot_pattern = re.compile(r"([A-Za-z0-9])(\d+)") for hotspot_str in hotspot_list: - match = hotspot_pattern.fullmatch(hotspot_str.strip()) # Add strip + match = hotspot_pattern.fullmatch(hotspot_str.strip()) # Add strip if not match: - return f"Hotspot '{hotspot_str}' is not in expected format (e.g., 'A50')." + return ( + f"Hotspot '{hotspot_str}' is not in expected format (e.g., 'A50')." + ) hs_chain_id, hs_res_num_str = match.groups() hs_res_num = int(hs_res_num_str) if hs_chain_id not in target_chain_details: - return f"Chain '{hs_chain_id}' for hotspot '{hotspot_str}' not in target. Valid: {list(target_chain_details.keys())}." + return ( + f"Chain '{hs_chain_id}' for hotspot '{hotspot_str}' not in target. " + f"Valid: {list(target_chain_details.keys())}." + ) chain_min = target_chain_details[hs_chain_id]["min_residue"] chain_max = target_chain_details[hs_chain_id]["max_residue"] if not (chain_min <= hs_res_num <= chain_max): - return (f"Residue {hs_res_num} for hotspot '{hotspot_str}' (chain '{hs_chain_id}') " - f"out of bounds. Chain '{hs_chain_id}' actual range: {chain_min}-{chain_max}.") + return ( + f"Residue {hs_res_num} for hotspot '{hotspot_str}' (chain '{hs_chain_id}') " + f"out of bounds. Chain '{hs_chain_id}' actual range: {chain_min}-{chain_max}." + ) return None async def _run_nim_alphafold2(self, args: Dict, workflow_state: Dict) -> Dict: @@ -96,13 +127,18 @@ class ToolExecutor: state_updates = {} if self.debug_protein_design_calls: - logger.warning(f"DEBUG MODE: Bypassing AlphaFold2 API call for workflow {item_id}.") + logger.warning( + f"DEBUG MODE: Bypassing AlphaFold2 API call for workflow {item_id}." + ) module_dir = Path(__file__).parent fixed_pdb_path = module_dir / "debug_target.pdb" if not fixed_pdb_path.exists(): logger.error(f"Debug mode failed: {fixed_pdb_path} not found.") - tool_output = {"success": False, "error": f"Debug mode failed: Required file {fixed_pdb_path} not found."} + tool_output = { + "success": False, + "error": f"Debug mode failed: Required file {fixed_pdb_path} not found.", + } return {"tool_output": tool_output, "state_updates": state_updates} with open(fixed_pdb_path, "r") as f: @@ -115,7 +151,10 @@ class ToolExecutor: state_updates["target_pdb_preview"] = pdb_preview state_updates["target_structure_predicted"] = True - debug_pdb_path = self.output_dir / f"target_{item_id}_s{current_internal_step}_af2_DEBUG.pdb" + debug_pdb_path = ( + self.output_dir + / f"target_{item_id}_s{current_internal_step}_af2_DEBUG.pdb" + ) with open(debug_pdb_path, "w") as f: f.write(pdb_content) logger.info(f"DEBUG MODE: Copied fixed AlphaFold2 PDB to {debug_pdb_path}") @@ -124,25 +163,31 @@ class ToolExecutor: "success": True, "message": "DEBUG MODE: Used fixed PDB for AlphaFold2.", "target_pdb_preview": pdb_preview, - "saved_pdb_path": str(debug_pdb_path) + "saved_pdb_path": str(debug_pdb_path), } return {"tool_output": tool_output, "state_updates": state_updates} sequence_from_llm = args.get("sequence") if not sequence_from_llm: - tool_output = {"success": False, "error": "Missing 'sequence' for AlphaFold2."} + tool_output = { + "success": False, + "error": "Missing 'sequence' for AlphaFold2.", + } return {"tool_output": tool_output, "state_updates": state_updates} actual_sequence_to_use = target_sequence_from_state if sequence_from_llm != target_sequence_from_state: logger.warning( f"LLM provided sequence '{sequence_from_llm[:20]}...' for 'predict_target_structure_alphafold2'. " - f"However, this tool will use the canonical target sequence from the workflow state: '{target_sequence_from_state[:20]}...'" + f"However, this tool will use the canonical target sequence from the workflow state: " + f"'{target_sequence_from_state[:20]}...'" ) api_result = await call_alphafold2( - sequence=actual_sequence_to_use, api_key=self.nim_api_key, - timeout=self.api_timeout, polling_interval=self.polling_interval + sequence=actual_sequence_to_use, + api_key=self.nim_api_key, + timeout=self.api_timeout, + polling_interval=self.polling_interval, ) if api_result and isinstance(api_result, list) and api_result[0]: pdb_content = api_result[0] @@ -153,20 +198,34 @@ class ToolExecutor: state_updates["target_pdb_preview"] = pdb_preview state_updates["target_structure_predicted"] = True - pdb_path = self.output_dir / f"target_{item_id}_s{current_internal_step}_af2.pdb" - with open(pdb_path, "w") as f: f.write(pdb_content) - logger.info(f"Workflow {item_id}: AlphaFold2 PDB saved to {pdb_path}. Chain details: {chain_details}") + pdb_path = ( + self.output_dir / f"target_{item_id}_s{current_internal_step}_af2.pdb" + ) + with open(pdb_path, "w") as f: + f.write(pdb_content) + logger.info( + f"Workflow {item_id}: AlphaFold2 PDB saved to {pdb_path}. " + f"Chain details: {chain_details}" + ) - tool_output = {"success": True, "message": "AlphaFold2 prediction complete.", "target_pdb_preview": pdb_preview, "saved_pdb_path": str(pdb_path)} + tool_output = { + "success": True, + "message": "AlphaFold2 prediction complete.", + "target_pdb_preview": pdb_preview, + "saved_pdb_path": str(pdb_path), + } else: - error_detail = api_result.get("error", "AlphaFold2 prediction failed.") if isinstance(api_result, dict) else "AlphaFold2 prediction failed." + error_detail = ( + api_result.get("error", "AlphaFold2 prediction failed.") + if isinstance(api_result, dict) + else "AlphaFold2 prediction failed." + ) logger.error(f"Workflow {item_id}: AlphaFold2 call failed: {error_detail}") tool_output = {"success": False, "error": error_detail} state_updates["target_structure_predicted"] = False return {"tool_output": tool_output, "state_updates": state_updates} - async def _run_nim_rfdiffusion(self, args: Dict, workflow_state: Dict) -> Dict: """ Runs RFDiffusion for binder backbone design. Returns structured output with @@ -182,30 +241,55 @@ class ToolExecutor: contigs_str_from_llm = args.get("contigs") if not target_pdb_content: - tool_output = {"success": False, "error": "Target PDB not found in state for RFDiffusion."} + tool_output = { + "success": False, + "error": "Target PDB not found in state for RFDiffusion.", + } return {"tool_output": tool_output, "state_updates": state_updates} if not contigs_str_from_llm: - tool_output = {"success": False, "error": "Missing 'contigs' for RFDiffusion."} + tool_output = { + "success": False, + "error": "Missing 'contigs' for RFDiffusion.", + } return {"tool_output": tool_output, "state_updates": state_updates} - validation_error = self._validate_rfd_contigs(contigs_str_from_llm, target_chain_details) + validation_error = self._validate_rfd_contigs( + contigs_str_from_llm, target_chain_details + ) if validation_error: - logger.warning(f"RFDiffusion contigs validation failed for item {item_id}: {validation_error}. Contigs: '{contigs_str_from_llm}'") - tool_output = {"success": False, "error": f"Invalid contigs: {validation_error}"} + logger.warning( + f"RFDiffusion contigs validation failed for item {item_id}: {validation_error}. " + f"Contigs: '{contigs_str_from_llm}'" + ) + tool_output = { + "success": False, + "error": f"Invalid contigs: {validation_error}", + } return {"tool_output": tool_output, "state_updates": state_updates} hotspot_residues = args.get("hotspot_residues") if hotspot_residues: - hotspot_validation_error = self._validate_rfd_hotspots(hotspot_residues, target_chain_details) + hotspot_validation_error = self._validate_rfd_hotspots( + hotspot_residues, target_chain_details + ) if hotspot_validation_error: - logger.warning(f"RFDiffusion hotspot validation failed for item {item_id}: {hotspot_validation_error}. Hotspots: {hotspot_residues}") - tool_output = {"success": False, "error": f"Invalid hotspots: {hotspot_validation_error}"} + logger.warning( + f"RFDiffusion hotspot validation failed for item {item_id}: {hotspot_validation_error}. " + f"Hotspots: {hotspot_residues}" + ) + tool_output = { + "success": False, + "error": f"Invalid hotspots: {hotspot_validation_error}", + } return {"tool_output": tool_output, "state_updates": state_updates} api_result = await call_rfdiffusion( - input_pdb=target_pdb_content, api_key=self.nim_api_key, - contigs=contigs_str_from_llm, hotspot_res=hotspot_residues, - timeout=self.api_timeout, polling_interval=self.polling_interval + input_pdb=target_pdb_content, + api_key=self.nim_api_key, + contigs=contigs_str_from_llm, + hotspot_res=hotspot_residues, + timeout=self.api_timeout, + polling_interval=self.polling_interval, ) if api_result and "output_pdb" in api_result: @@ -217,20 +301,34 @@ class ToolExecutor: state_updates["binder_pdb_preview"] = binder_pdb_preview state_updates["binder_backbone_designed"] = True - pdb_path = self.output_dir / f"binder_backbone_{item_id}_s{current_internal_step}_rfd.pdb" - with open(pdb_path, "w") as f: f.write(binder_pdb) + pdb_path = ( + self.output_dir + / f"binder_backbone_{item_id}_s{current_internal_step}_rfd.pdb" + ) + with open(pdb_path, "w") as f: + f.write(binder_pdb) logger.info(f"Workflow {item_id}: RFDiffusion PDB saved to {pdb_path}") - tool_output = {"success": True, "message": "RFDiffusion complete.", "binder_backbone_pdb_preview": binder_pdb_preview, "saved_pdb_path": str(pdb_path)} + tool_output = { + "success": True, + "message": "RFDiffusion complete.", + "binder_backbone_pdb_preview": binder_pdb_preview, + "saved_pdb_path": str(pdb_path), + } else: - error_detail = api_result.get("error", "RFDiffusion failed.") if isinstance(api_result, dict) else "RFDiffusion failed." - logger.error(f"Workflow {item_id}: RFDiffusion call failed: {error_detail}. API Result: {api_result}") + error_detail = ( + api_result.get("error", "RFDiffusion failed.") + if isinstance(api_result, dict) + else "RFDiffusion failed." + ) + logger.error( + f"Workflow {item_id}: RFDiffusion call failed: {error_detail}. API Result: {api_result}" + ) tool_output = {"success": False, "error": error_detail} state_updates["binder_backbone_designed"] = False return {"tool_output": tool_output, "state_updates": state_updates} - async def _run_nim_proteinmpnn(self, args: Dict, workflow_state: Dict) -> Dict: """ Runs ProteinMPNN for binder sequence design. Returns structured output with @@ -244,20 +342,33 @@ class ToolExecutor: state_updates = {} if not binder_pdb: - tool_output = {"success": False, "error": "Binder backbone PDB not found for ProteinMPNN."} + tool_output = { + "success": False, + "error": "Binder backbone PDB not found for ProteinMPNN.", + } return {"tool_output": tool_output, "state_updates": state_updates} sampling_temp_list = args.get("sampling_temp", [0.1]) api_result = await call_proteinmpnn( - input_pdb=binder_pdb, api_key=self.nim_api_key, + input_pdb=binder_pdb, + api_key=self.nim_api_key, sampling_temp=sampling_temp_list, - timeout=self.api_timeout, polling_interval=self.polling_interval + timeout=self.api_timeout, + polling_interval=self.polling_interval, ) if not (api_result and "mfasta" in api_result): - error_detail = api_result.get("error", "ProteinMPNN call failed or no mfasta in result.") if isinstance(api_result, dict) else "PMPNN call failed" - logger.error(f"Workflow {item_id}: ProteinMPNN call failed: {error_detail}. API Result: {api_result}") + error_detail = ( + api_result.get( + "error", "ProteinMPNN call failed or no mfasta in result." + ) + if isinstance(api_result, dict) + else "PMPNN call failed" + ) + logger.error( + f"Workflow {item_id}: ProteinMPNN call failed: {error_detail}. API Result: {api_result}" + ) tool_output = {"success": False, "error": error_detail} state_updates["binder_sequence_designed"] = False return {"tool_output": tool_output, "state_updates": state_updates} @@ -268,12 +379,15 @@ class ToolExecutor: current_sequence_parts: List[str] = [] for line_content in fasta_content.splitlines(): line = line_content.strip() - if not line: continue + if not line: + continue if line.startswith(">"): if current_header and current_sequence_parts: full_sequence_line = "".join(current_sequence_parts) score_match = re.search(r"global_score=([-\d.]+)", current_header) - global_score = float(score_match.group(1)) if score_match else -float('inf') + global_score = ( + float(score_match.group(1)) if score_match else -float("inf") + ) entries.append((global_score, current_header, full_sequence_line)) current_header = line current_sequence_parts = [] @@ -282,7 +396,7 @@ class ToolExecutor: if current_header and current_sequence_parts: full_sequence_line = "".join(current_sequence_parts) score_match = re.search(r"global_score=([-\d.]+)", current_header) - global_score = float(score_match.group(1)) if score_match else -float('inf') + global_score = float(score_match.group(1)) if score_match else -float("inf") entries.append((global_score, current_header, full_sequence_line)) if not entries: @@ -292,36 +406,59 @@ class ToolExecutor: entries.sort(key=lambda x: x[0], reverse=True) best_global_score, best_header, best_full_sequence_line = entries[0] - logger.info(f"Workflow {item_id}: Best PMPNN sequence chosen (global_score={best_global_score:.4f}) from header: '{best_header}' -> Seq line: '{best_full_sequence_line}'") + logger.info( + f"Workflow {item_id}: Best PMPNN sequence chosen (global_score={best_global_score:.4f}) " + f"from header: '{best_header}' -> Seq line: '{best_full_sequence_line}'" + ) - parsed_binder_chains = [s.strip() for s in best_full_sequence_line.split('/') if s.strip()] + parsed_binder_chains = [ + s.strip() for s in best_full_sequence_line.split("/") if s.strip() + ] - if not parsed_binder_chains or not all(s and s.isalpha() and s.isupper() for s in parsed_binder_chains): - tool_output = {"success": False, "error": f"Invalid binder chains from PMPNN after parsing '{best_full_sequence_line}'. Parsed: {parsed_binder_chains}"} + if not parsed_binder_chains or not all( + s and s.isalpha() and s.isupper() for s in parsed_binder_chains + ): + tool_output = { + "success": False, + "error": ( + f"Invalid binder chains from PMPNN after parsing '{best_full_sequence_line}'. " + f"Parsed: {parsed_binder_chains}" + ), + } state_updates["binder_sequence_designed"] = False return {"tool_output": tool_output, "state_updates": state_updates} state_updates["designed_binder_sequence"] = parsed_binder_chains state_updates["binder_sequence_designed"] = True - fasta_path = self.output_dir / f"binder_sequence_{item_id}_s{current_internal_step}_pmpnn.fasta" - with open(fasta_path, "w") as f: f.write(fasta_content) - logger.info(f"Workflow {item_id}: ProteinMPNN FASTA saved to {fasta_path}. Selected binder chains: {parsed_binder_chains}") + fasta_path = ( + self.output_dir + / f"binder_sequence_{item_id}_s{current_internal_step}_pmpnn.fasta" + ) + with open(fasta_path, "w") as f: + f.write(fasta_content) + logger.info( + f"Workflow {item_id}: ProteinMPNN FASTA saved to {fasta_path}. " + f"Selected binder chains: {parsed_binder_chains}" + ) - preview = parsed_binder_chains[0][:60] + "..." if parsed_binder_chains else "N/A" + preview = ( + parsed_binder_chains[0][:60] + "..." if parsed_binder_chains else "N/A" + ) if len(parsed_binder_chains) > 1: preview += f" (+ {len(parsed_binder_chains)-1} more chain(s))" tool_output = { "success": True, - "message": f"ProteinMPNN complete. Selected best (global_score={best_global_score:.4f}).", + "message": ( + f"ProteinMPNN complete. Selected best (global_score={best_global_score:.4f})." + ), "designed_binder_sequence_list": parsed_binder_chains, "designed_binder_sequence_preview": preview, - "saved_fasta_path": str(fasta_path) + "saved_fasta_path": str(fasta_path), } return {"tool_output": tool_output, "state_updates": state_updates} - async def _run_nim_af2_multimer(self, args: Dict, workflow_state: Dict) -> Dict: item_id = workflow_state["item_id"] current_internal_step = workflow_state["current_internal_step"] @@ -331,16 +468,31 @@ class ToolExecutor: tool_output = {} state_updates = {} - if not target_seq or not designed_binder_chains_list or not isinstance(designed_binder_chains_list, list): - tool_output = {"success": False, "error": "Missing or invalid sequences for AF2-Multimer."} + if ( + not target_seq + or not designed_binder_chains_list + or not isinstance(designed_binder_chains_list, list) + ): + tool_output = { + "success": False, + "error": "Missing or invalid sequences for AF2-Multimer.", + } return {"tool_output": tool_output, "state_updates": state_updates} all_input_sequences_for_multimer = [target_seq] + designed_binder_chains_list for i, seq_to_validate in enumerate(all_input_sequences_for_multimer): - if not (seq_to_validate and isinstance(seq_to_validate, str) and seq_to_validate.isalpha() and seq_to_validate.isupper()): - error_msg = (f"Sequence {i+1} (part of target/binder complex) is invalid: " - f"'{str(seq_to_validate)[:30]}...'. Contains non-alpha/lowercase, is empty, or not a string.") + if not ( + seq_to_validate + and isinstance(seq_to_validate, str) + and seq_to_validate.isalpha() + and seq_to_validate.isupper() + ): + error_msg = ( + f"Sequence {i+1} (part of target/binder complex) is invalid: " + f"'{str(seq_to_validate)[:30]}...'. " + f"Contains non-alpha/lowercase, is empty, or not a string." + ) logger.error(f"Workflow {item_id}: {error_msg}") tool_output = {"success": False, "error": error_msg} return {"tool_output": tool_output, "state_updates": state_updates} @@ -350,28 +502,41 @@ class ToolExecutor: if self.debug_protein_design_calls: self._debug_af2m_call_count += 1 mock_plddt = 87.5 if self._debug_af2m_call_count % 2 == 1 else 45.2 - success_message = f"DEBUG MODE: Returning {'high' if mock_plddt > 50 else 'low'}-quality mock results (call #{self._debug_af2m_call_count})" + success_message = ( + f"DEBUG MODE: Returning {'high' if mock_plddt > 50 else 'low'}-quality " + f"mock results (call #{self._debug_af2m_call_count})" + ) debug_pdb_filename = f"complex_{item_id}_s{current_internal_step}_af2m_DEBUG_pLDDT{mock_plddt:.2f}.pdb" debug_pdb_path = self.output_dir / debug_pdb_filename try: with open(debug_pdb_path, "w") as f: - f.write(f"REMARK DEBUG PDB FILE for complex. Predicted pLDDT {mock_plddt}\n") - logger.info(f"DEBUG MODE: Saved mock AF2-Multimer PDB to {debug_pdb_path}") + f.write( + f"REMARK DEBUG PDB FILE for complex. Predicted pLDDT {mock_plddt}\n" + ) + logger.info( + f"DEBUG MODE: Saved mock AF2-Multimer PDB to {debug_pdb_path}" + ) state_updates["complex_pdb_content_path"] = str(debug_pdb_path) except IOError as e: - logger.error(f"DEBUG MODE: Failed to write mock PDB {debug_pdb_path}: {e}") + logger.error( + f"DEBUG MODE: Failed to write mock PDB {debug_pdb_path}: {e}" + ) # If saving fails, don't set the path, but can still proceed with mock pLDDT state_updates["complex_pdb_content_path"] = None - state_updates["af2_multimer_plddt"] = mock_plddt state_updates["complex_evaluated"] = True tool_output = { - "success": True, "message": f"{success_message}. Mock pLDDT: {mock_plddt:.2f}", + "success": True, + "message": f"{success_message}. Mock pLDDT: {mock_plddt:.2f}", "plddt": mock_plddt, - "complex_file_path": str(debug_pdb_path) if state_updates["complex_pdb_content_path"] else None + "complex_file_path": ( + str(debug_pdb_path) + if state_updates["complex_pdb_content_path"] + else None + ), } return {"tool_output": tool_output, "state_updates": state_updates} @@ -380,36 +545,57 @@ class ToolExecutor: api_key=self.nim_api_key, relax_prediction=relax, timeout=self.api_timeout, - polling_interval=self.polling_interval + polling_interval=self.polling_interval, ) - if api_result is None or (isinstance(api_result, dict) and api_result.get("success") is False): + if api_result is None or ( + isinstance(api_result, dict) and api_result.get("success") is False + ): error_detail = "AF2-Multimer call failed or returned None." if isinstance(api_result, dict): - error_detail = api_result.get("error", "AF2-Multimer call failed with unspecified error.") + error_detail = api_result.get( + "error", "AF2-Multimer call failed with unspecified error." + ) detail_info = api_result.get("detail", "") - if detail_info: error_detail += f" Details: {detail_info}" + if detail_info: + error_detail += f" Details: {detail_info}" - logger.error(f"Workflow {item_id}: AF2-Multimer call failed: {error_detail}. API Result: {api_result}") + logger.error( + f"Workflow {item_id}: AF2-Multimer call failed: {error_detail}. " + f"API Result: {api_result}" + ) tool_output = {"success": False, "error": error_detail} state_updates["complex_evaluated"] = False return {"tool_output": tool_output, "state_updates": state_updates} all_structures_info = api_result.get("structures") if not all_structures_info or not isinstance(all_structures_info, list): - message = api_result.get("message", "No structures returned from AF2-Multimer process.") + message = api_result.get( + "message", "No structures returned from AF2-Multimer process." + ) logger.warning(f"Workflow {item_id}: {message}. API Result: {api_result}") if not all_structures_info and isinstance(all_structures_info, list): - tool_output = {"success": True, "message": "AF2-Multimer ran, but no PDB structures were produced by the API.", "plddt": 0.0, "complex_file_path": None} - state_updates["af2_multimer_plddt"] = 0.0 - state_updates["complex_evaluated"] = True - state_updates["complex_pdb_content_path"] = None + tool_output = { + "success": True, + "message": ( + "AF2-Multimer ran, but no PDB structures were produced by the API." + ), + "plddt": 0.0, + "complex_file_path": None, + } + state_updates["af2_multimer_plddt"] = 0.0 + state_updates["complex_evaluated"] = True + state_updates["complex_pdb_content_path"] = None else: - tool_output = {"success": False, "error": "AF2-Multimer returned unexpected data or no structures."} + tool_output = { + "success": False, + "error": ( + "AF2-Multimer returned unexpected data or no structures." + ), + } state_updates["complex_evaluated"] = False return {"tool_output": tool_output, "state_updates": state_updates} - best_structure_info = None highest_plddt = -1.0 @@ -420,8 +606,14 @@ class ToolExecutor: best_structure_info = struct_info if best_structure_info is None: - logger.error(f"Workflow {item_id}: No valid structure with pLDDT found in AF2-Multimer results, though structures were present.") - tool_output = {"success": False, "error": "No valid structure with pLDDT in AF2-Multimer results."} + logger.error( + f"Workflow {item_id}: No valid structure with pLDDT found in AF2-Multimer results, " + f"though structures were present." + ) + tool_output = { + "success": False, + "error": ("No valid structure with pLDDT in AF2-Multimer results."), + } state_updates["complex_evaluated"] = False return {"tool_output": tool_output, "state_updates": state_updates} @@ -430,53 +622,81 @@ class ToolExecutor: best_model_idx = best_structure_info.get("model_index", "NA") if not best_pdb_content: - logger.error(f"Workflow {item_id}: Best AF2-Multimer structure (Model {best_model_idx}, pLDDT {best_plddt:.2f}) found, but PDB content is missing.") - tool_output = {"success": False, "error": f"Best model (pLDDT {best_plddt:.2f}) has no PDB content."} + logger.error( + f"Workflow {item_id}: Best AF2-Multimer structure (Model {best_model_idx}, " + f"pLDDT {best_plddt:.2f}) found, but PDB content is missing." + ) + tool_output = { + "success": False, + "error": f"Best model (pLDDT {best_plddt:.2f}) has no PDB content.", + } state_updates["complex_evaluated"] = False state_updates["af2_multimer_plddt"] = best_plddt return {"tool_output": tool_output, "state_updates": state_updates} - complex_pdb_filename = f"complex_{item_id}_s{current_internal_step}_af2m_model{best_model_idx}_pLDDT{best_plddt:.2f}.pdb" + complex_pdb_filename = ( + f"complex_{item_id}_s{current_internal_step}_af2m_model{best_model_idx}_" + f"pLDDT{best_plddt:.2f}.pdb" + ) complex_pdb_path = self.output_dir / complex_pdb_filename try: - with open(complex_pdb_path, "w", encoding='utf-8') as f: + with open(complex_pdb_path, "w", encoding="utf-8") as f: f.write(best_pdb_content) - logger.info(f"Workflow {item_id}: AlphaFold2-Multimer complete. Saved best model (Index {best_model_idx}) with pLDDT: {best_plddt:.2f} from {len(all_structures_info)} models to {complex_pdb_path}") + logger.info( + f"Workflow {item_id}: AlphaFold2-Multimer complete. Saved best model (Index {best_model_idx}) " + f"with pLDDT: {best_plddt:.2f} from {len(all_structures_info)} models to {complex_pdb_path}" + ) state_updates["complex_pdb_content_path"] = str(complex_pdb_path) state_updates["af2_multimer_plddt"] = best_plddt state_updates["complex_evaluated"] = True - complex_quality_message = f"AlphaFold2-Multimer evaluation complete. Selected best model (Index {best_model_idx}) with pLDDT: {best_plddt:.2f}" + complex_quality_message = ( + f"AlphaFold2-Multimer evaluation complete. Selected best model (Index {best_model_idx}) " + f"with pLDDT: {best_plddt:.2f}" + ) tool_output = { "success": True, "message": complex_quality_message, "plddt": best_plddt, "complex_file_path": str(complex_pdb_path), - "selected_model_index": best_model_idx + "selected_model_index": best_model_idx, } except IOError as e: - logger.error(f"Workflow {item_id}: Failed to save best AF2-Multimer PDB (Model {best_model_idx}, pLDDT {best_plddt:.2f}) to {complex_pdb_path}: {e}") - tool_output = {"success": False, "error": f"Failed to save best complex PDB: {e}"} + logger.error( + f"Workflow {item_id}: Failed to save best AF2-Multimer PDB (Model {best_model_idx}, " + f"pLDDT {best_plddt:.2f}) to {complex_pdb_path}: {e}" + ) + tool_output = { + "success": False, + "error": f"Failed to save best complex PDB: {e}", + } state_updates["af2_multimer_plddt"] = best_plddt state_updates["complex_pdb_content_path"] = None state_updates["complex_evaluated"] = True return {"tool_output": tool_output, "state_updates": state_updates} - - async def dispatch_tool_call(self, tool_name: str, args: Dict, workflow_state: Dict) -> Dict: + async def dispatch_tool_call( + self, tool_name: str, args: Dict, workflow_state: Dict + ) -> Dict: """Main dispatch method for executing tools.""" item_id = workflow_state["item_id"] internal_step = workflow_state["current_internal_step"] - logger.info(f"ToolExecutor: Dispatching tool '{tool_name}' for workflow {item_id}, Step {internal_step} with args: {args}") + logger.info( + f"ToolExecutor: Dispatching tool '{tool_name}' for workflow {item_id}, " + f"Step {internal_step} with args: {args}" + ) if not self.nim_api_key: return { - "tool_output": {"success": False, "error": "NIM API key not configured in ToolExecutor."}, - "state_updates": {} + "tool_output": { + "success": False, + "error": "NIM API key not configured in ToolExecutor.", + }, + "state_updates": {}, } if tool_name == "predict_target_structure_alphafold2": @@ -488,8 +708,13 @@ class ToolExecutor: elif tool_name == "evaluate_binder_complex_alphafold2_multimer": return await self._run_nim_af2_multimer(args, workflow_state) else: - logger.error(f"ToolExecutor: Unknown tool name '{tool_name}' for workflow {item_id}") + logger.error( + f"ToolExecutor: Unknown tool name '{tool_name}' for workflow {item_id}" + ) return { - "tool_output": {"success": False, "error": f"Unknown tool name: {tool_name}"}, - "state_updates": {} + "tool_output": { + "success": False, + "error": f"Unknown tool name: {tool_name}", + }, + "state_updates": {}, } diff --git a/environments/hack0/protein_design_env/utils/__init__.py b/environments/community/protein_design/utils/__init__.py similarity index 74% rename from environments/hack0/protein_design_env/utils/__init__.py rename to environments/community/protein_design/utils/__init__.py index 2610ebd5..dd517116 100644 --- a/environments/hack0/protein_design_env/utils/__init__.py +++ b/environments/community/protein_design/utils/__init__.py @@ -2,4 +2,4 @@ from .pdb_utils import get_pdb_chain_details -__all__ = ["get_pdb_chain_details"] \ No newline at end of file +__all__ = ["get_pdb_chain_details"] diff --git a/environments/hack0/protein_design_env/utils/api_utils.py b/environments/community/protein_design/utils/api_utils.py similarity index 71% rename from environments/hack0/protein_design_env/utils/api_utils.py rename to environments/community/protein_design/utils/api_utils.py index 6f4bd2ee..b6eac70a 100644 --- a/environments/hack0/protein_design_env/utils/api_utils.py +++ b/environments/community/protein_design/utils/api_utils.py @@ -1,13 +1,13 @@ -import os import logging -import yaml -from pathlib import Path +import os from typing import Optional + from dotenv import load_dotenv load_dotenv() logger = logging.getLogger(__name__) + def load_api_key() -> Optional[str]: """ Load the NVIDIA NIM API key from environment variables. @@ -17,8 +17,10 @@ def load_api_key() -> Optional[str]: """ api_key = os.environ.get("NVIDIA_NIM_API_KEY") if not api_key: - logger.error("NVIDIA_NIM_API_KEY not found in environment variables. " - "Please set it in your .env file.") + logger.error( + "NVIDIA_NIM_API_KEY not found in environment variables. " + "Please set it in your .env file." + ) return None return api_key diff --git a/environments/hack0/protein_design_env/utils/pdb_utils.py b/environments/community/protein_design/utils/pdb_utils.py similarity index 80% rename from environments/hack0/protein_design_env/utils/pdb_utils.py rename to environments/community/protein_design/utils/pdb_utils.py index 8a59d88e..c24e3603 100644 --- a/environments/hack0/protein_design_env/utils/pdb_utils.py +++ b/environments/community/protein_design/utils/pdb_utils.py @@ -1,9 +1,12 @@ import logging -from typing import Dict, Tuple, List, Set, Union +from typing import Dict, Set, Tuple, Union logger = logging.getLogger(__name__) -def get_pdb_chain_details(pdb_content: str, preview_lines: int = 10) -> Tuple[Dict[str, Dict[str, int]], str]: + +def get_pdb_chain_details( + pdb_content: str, preview_lines: int = 10 +) -> Tuple[Dict[str, Dict[str, int]], str]: """ Parses PDB content to extract detailed information for each chain. @@ -27,7 +30,8 @@ def get_pdb_chain_details(pdb_content: str, preview_lines: int = 10) -> Tuple[Di if line.startswith("ATOM"): atom_lines.append(line) chain_id = line[21:22].strip() - if not chain_id: chain_id = " " + if not chain_id: + chain_id = " " atom_name = line[12:16].strip() try: residue_num = int(line[22:26].strip()) @@ -39,7 +43,11 @@ def get_pdb_chain_details(pdb_content: str, preview_lines: int = 10) -> Tuple[Di except ValueError: logger.warning(f"Could not parse residue number from PDB line: {line}") continue - elif line.startswith("HEADER") or line.startswith("TITLE") or line.startswith("COMPND"): + elif ( + line.startswith("HEADER") + or line.startswith("TITLE") + or line.startswith("COMPND") + ): header_lines.append(line) chain_details: Dict[str, Dict[str, int]] = {} @@ -51,14 +59,16 @@ def get_pdb_chain_details(pdb_content: str, preview_lines: int = 10) -> Tuple[Di chain_details[chain_id] = { "min_residue": min_res, "max_residue": max_res, - "length": length + "length": length, } else: logger.warning(f"Chain {chain_id} had no parseable ATOM residue numbers.") - preview_str_parts = header_lines[:min(len(header_lines), preview_lines // 2)] + preview_str_parts = header_lines[: min(len(header_lines), preview_lines // 2)] remaining_preview_lines = preview_lines - len(preview_str_parts) - preview_str_parts.extend(atom_lines[:min(len(atom_lines), remaining_preview_lines)]) + preview_str_parts.extend( + atom_lines[: min(len(atom_lines), remaining_preview_lines)] + ) pdb_preview = "\n".join(preview_str_parts) if len(pdb_content.splitlines()) > preview_lines: pdb_preview += "\n..." diff --git a/environments/hack0/protein_design_env/models/__init__.py b/environments/hack0/protein_design_env/models/__init__.py deleted file mode 100644 index 2cd7109b..00000000 --- a/environments/hack0/protein_design_env/models/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Protein design model API modules.""" \ No newline at end of file diff --git a/environments/hack0/protein_design_env/tool_definitions.py b/environments/hack0/protein_design_env/tool_definitions.py deleted file mode 100644 index 86ee0b59..00000000 --- a/environments/hack0/protein_design_env/tool_definitions.py +++ /dev/null @@ -1,67 +0,0 @@ -PREDICT_TARGET_STRUCTURE_TOOL = { - "type": "function", - "function": { - "name": "predict_target_structure_alphafold2", - "description": "Predicts the 3D structure of the target protein sequence using AlphaFold2.", - "parameters": { - "type": "object", - "properties": { - "sequence": {"type": "string", "description": "Amino acid sequence of the target protein."}, - }, - "required": ["sequence"] - } - } -} - -DESIGN_BINDER_BACKBONE_TOOL = { - "type": "function", - "function": { - "name": "design_binder_backbone_rfdiffusion", - "description": "Generates a novel protein binder backbone using RFDiffusion, conditioned on the target protein structure.", - "parameters": { - "type": "object", - "properties": { - "contigs": {"type": "string", "description": "RFDiffusion contigs (e.g., 'A1-100/0 50-70')."}, - "hotspot_residues": {"type": "array", "items": {"type": "string"}, "description": "Optional hotspot residues (e.g., ['A50', 'A52'])."}, - }, - "required": ["contigs"] - } - } -} - -DESIGN_BINDER_SEQUENCE_TOOL = { - "type": "function", - "function": { - "name": "design_binder_sequence_proteinmpnn", - "description": "Designs an amino acid sequence for the generated binder backbone.", - "parameters": { - "type": "object", - "properties": { - "sampling_temp": {"type": "array", "items": {"type": "number"}, "description": "Sampling temperatures (e.g., [0.1, 0.2]). Default [0.1]."} - }, - "required": [] - } - } -} - -EVALUATE_COMPLEX_TOOL = { - "type": "function", - "function": { - "name": "evaluate_binder_complex_alphafold2_multimer", - "description": "Predicts the complex structure of target and designed binder, providing quality scores.", - "parameters": { - "type": "object", - "properties": { - "relax_prediction": {"type": "boolean", "description": "Whether to relax the prediction. Default True."} - }, - "required": [] - } - } -} - -ALL_TOOLS_LIST = [ - PREDICT_TARGET_STRUCTURE_TOOL, - DESIGN_BINDER_BACKBONE_TOOL, - DESIGN_BINDER_SEQUENCE_TOOL, - EVALUATE_COMPLEX_TOOL -]