mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
additional fixes to alphafold2_multimer and tool_executor
This commit is contained in:
parent
6783a077cc
commit
b01023ad3a
2 changed files with 30 additions and 56 deletions
|
|
@ -4,14 +4,13 @@ import aiohttp
|
|||
import json
|
||||
import asyncio
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from pathlib import Path # Keep Path for type hinting if used elsewhere, but not for output_dir here
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2-multimer"
|
||||
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
|
||||
|
||||
# _split_pdb_content remains the same
|
||||
def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
|
||||
"""
|
||||
Splits a string containing concatenated PDB file contents.
|
||||
|
|
@ -24,7 +23,7 @@ def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
|
|||
|
||||
for line in concatenated_pdb_str.splitlines(keepends=True):
|
||||
current_pdb_lines.append(line)
|
||||
if line.startswith("ENDMDL") or line.startswith("END "): # Fixed: added space for "END "
|
||||
if line.startswith("ENDMDL") or line.startswith("END "):
|
||||
pdbs.append("".join(current_pdb_lines).strip())
|
||||
current_pdb_lines = []
|
||||
|
||||
|
|
@ -36,7 +35,6 @@ def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
|
|||
return [pdb for pdb in pdbs if pdb]
|
||||
|
||||
|
||||
# calculate_plddt_from_pdb_string remains the same
|
||||
def calculate_plddt_from_pdb_string(pdb_string: str) -> Tuple[float, List[float], Dict[str, List[float]]]:
|
||||
total_plddt = 0.0
|
||||
ca_atom_count = 0
|
||||
|
|
@ -70,10 +68,10 @@ def calculate_plddt_from_pdb_string(pdb_string: str) -> Tuple[float, List[float]
|
|||
return average_plddt, plddt_scores_per_ca, plddt_scores_per_chain
|
||||
|
||||
async def _process_pdb_and_scores_from_api(
|
||||
pdb_contents: List[str], # This is the list of PDB strings from the API
|
||||
pdb_contents: List[str],
|
||||
job_id: str,
|
||||
api_response_json: Optional[Dict[str, Any]] = None # Not currently used, but kept for consistency
|
||||
) -> Optional[Dict[str, Any]]: # Return structure: {"structures": [...]}
|
||||
api_response_json: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Processes a list of PDB strings received from the API.
|
||||
- Calculates pLDDT scores for each PDB string.
|
||||
|
|
@ -86,7 +84,6 @@ async def _process_pdb_and_scores_from_api(
|
|||
|
||||
if not pdb_contents or not isinstance(pdb_contents, list) or not all(isinstance(s, str) for s in pdb_contents):
|
||||
logger.warning(f"No valid PDB content strings provided for job {job_id}.")
|
||||
# Return a structure indicating failure or empty results, consistent with how call_alphafold2_multimer handles errors
|
||||
return {"success": False, "error": "No valid PDB content strings from API.", "structures": []}
|
||||
|
||||
|
||||
|
|
@ -98,8 +95,8 @@ async def _process_pdb_and_scores_from_api(
|
|||
continue
|
||||
|
||||
structure_data: Dict[str, Any] = {
|
||||
"model_index": i, # Keep model_index for identification
|
||||
"pdb_content": pdb_str # Store the actual PDB content
|
||||
"model_index": i,
|
||||
"pdb_content": pdb_str
|
||||
}
|
||||
|
||||
avg_plddt, plddts_per_ca_residue, plddts_by_chain = calculate_plddt_from_pdb_string(pdb_str)
|
||||
|
|
@ -115,9 +112,6 @@ async def _process_pdb_and_scores_from_api(
|
|||
else:
|
||||
avg_plddt_per_chain[chain_id] = 0.0
|
||||
structure_data["average_plddt_per_chain"] = avg_plddt_per_chain
|
||||
|
||||
# REMOVED FILE SAVING LOGIC HERE
|
||||
# structure_data["saved_pdb_path"] = ... (NO LONGER SET HERE)
|
||||
|
||||
results["structures"].append(structure_data)
|
||||
|
||||
|
|
@ -125,11 +119,9 @@ async def _process_pdb_and_scores_from_api(
|
|||
logger.info(f"Successfully processed and calculated pLDDTs for {len(results['structures'])} structures for job {job_id}.")
|
||||
else:
|
||||
logger.warning(f"No structures were processed for job {job_id}.")
|
||||
# Ensure a consistent return structure even if no structures processed
|
||||
return {"success": True, "message": "No PDB structures found in API response to process.", "structures": []}
|
||||
|
||||
|
||||
return results # This will contain {"structures": [...list of dicts with pdb_content and scores...]}
|
||||
return results
|
||||
|
||||
async def call_alphafold2_multimer(
|
||||
sequences: List[str],
|
||||
|
|
@ -144,7 +136,6 @@ async def call_alphafold2_multimer(
|
|||
status_url: str = DEFAULT_STATUS_URL,
|
||||
polling_interval: int = 30,
|
||||
timeout: int = 3600
|
||||
# REMOVED output_dir: Optional[Path] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Call the NVIDIA NIM AlphaFold2-Multimer API.
|
||||
|
|
@ -186,17 +177,16 @@ async def call_alphafold2_multimer(
|
|||
if "application/json" in content_type:
|
||||
api_response_json_payload = await response.json()
|
||||
if not isinstance(api_response_json_payload, list):
|
||||
# Handle case where API might return {"error": ...} directly in a sync 200
|
||||
if isinstance(api_response_json_payload, dict) and "error" in api_response_json_payload:
|
||||
logger.error(f"Sync API call returned error: {api_response_json_payload['error']}")
|
||||
return {"success": False, "error": api_response_json_payload['error'], "detail": api_response_json_payload.get("detail","")}
|
||||
return {"success": False, "error": "Sync JSON response not a list of PDBs as expected."}
|
||||
|
||||
|
||||
req_id_sync = response.headers.get("nvcf-reqid", "sync_job")
|
||||
return await _process_pdb_and_scores_from_api( # No output_dir
|
||||
return await _process_pdb_and_scores_from_api(
|
||||
pdb_contents=api_response_json_payload,
|
||||
job_id=req_id_sync,
|
||||
api_response_json=None # Not used currently
|
||||
api_response_json=None
|
||||
)
|
||||
else:
|
||||
err_text = await response.text()
|
||||
|
|
@ -207,7 +197,7 @@ async def call_alphafold2_multimer(
|
|||
req_id = response.headers.get("nvcf-reqid")
|
||||
if req_id:
|
||||
logger.info(f"AlphaFold2-Multimer job submitted, request ID: {req_id}")
|
||||
return await _poll_job_status( # No output_dir
|
||||
return await _poll_job_status(
|
||||
req_id=req_id,
|
||||
headers=headers,
|
||||
status_url=status_url,
|
||||
|
|
@ -235,14 +225,12 @@ async def _poll_job_status(
|
|||
status_url: str,
|
||||
polling_interval: int = 30,
|
||||
overall_timeout: int = 3600
|
||||
# REMOVED output_dir: Optional[Path] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
per_status_request_timeout = 600
|
||||
logger.info(f"Polling job {req_id}. Individual status check timeout: {per_status_request_timeout}s, Polling interval: {polling_interval}s, Overall timeout: {overall_timeout}s")
|
||||
|
||||
while True:
|
||||
# ... (polling logic for time checks remains the same) ...
|
||||
current_loop_time = asyncio.get_event_loop().time()
|
||||
elapsed_time = current_loop_time - start_time
|
||||
|
||||
|
|
@ -265,23 +253,22 @@ async def _poll_job_status(
|
|||
headers=headers,
|
||||
timeout=current_status_check_timeout
|
||||
) as response:
|
||||
if response.status == 200: # Job completed
|
||||
if response.status == 200:
|
||||
logger.info(f"AlphaFold2-Multimer job {req_id} completed (status 200).")
|
||||
if response.content_type == 'application/json':
|
||||
try:
|
||||
api_response_json_payload = await response.json()
|
||||
if not isinstance(api_response_json_payload, list):
|
||||
# Handle case where API might return {"error": ...} directly
|
||||
if isinstance(api_response_json_payload, dict) and "error" in api_response_json_payload:
|
||||
logger.error(f"Job {req_id}: API returned error: {api_response_json_payload['error']}")
|
||||
return {"success": False, "error": api_response_json_payload['error'], "detail": api_response_json_payload.get("detail","")}
|
||||
logger.error(f"Job {req_id}: Expected API response to be a list of PDB strings, got {type(api_response_json_payload)}.")
|
||||
return {"success": False, "error": "API response was not a list of PDB strings."}
|
||||
|
||||
return await _process_pdb_and_scores_from_api( # No output_dir
|
||||
return await _process_pdb_and_scores_from_api(
|
||||
pdb_contents=api_response_json_payload,
|
||||
job_id=req_id,
|
||||
api_response_json=None # Not used currently
|
||||
api_response_json=None
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Job {req_id}: Failed to decode JSON response from API.", exc_info=True)
|
||||
|
|
@ -291,7 +278,6 @@ async def _poll_job_status(
|
|||
raw_text = await response.text()
|
||||
logger.error(f"Job {req_id}: Unexpected content type {response.content_type}. Expected application/json. Response: {raw_text[:500]}")
|
||||
return {"success": False, "error": f"Unexpected content type: {response.content_type}", "detail": raw_text}
|
||||
# ... (rest of polling logic: 202, errors, timeouts remains the same) ...
|
||||
elif response.status == 202:
|
||||
try:
|
||||
job_status_json = await response.json()
|
||||
|
|
@ -311,10 +297,10 @@ async def _poll_job_status(
|
|||
return {"success": False, "error": f"Status check failed with HTTP {response.status}", "detail": text}
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Client-side timeout ({current_status_check_timeout}s) during status check for job {req_id}. Retrying poll after {polling_interval}s sleep.")
|
||||
await asyncio.sleep(polling_interval) # Sleep before next attempt
|
||||
await asyncio.sleep(polling_interval)
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"Client error polling job status for {req_id}: {e}. Retrying poll after {polling_interval}s.", exc_info=True)
|
||||
await asyncio.sleep(polling_interval)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error polling job status {req_id}: {e}", exc_info=True)
|
||||
return {"success": False, "error": f"Unexpected polling error: {str(e)}"}
|
||||
return {"success": False, "error": f"Unexpected polling error: {str(e)}"}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue