refactor file saving from alphafold2_multimer to tool_executor

This commit is contained in:
based-tachikoma 2025-05-21 21:23:00 -07:00
parent 227e594ebf
commit 02585947e4
2 changed files with 162 additions and 156 deletions

View file

@ -4,13 +4,14 @@ import aiohttp
import json
import asyncio
from typing import Dict, List, Any, Optional, Tuple
from pathlib import Path
from pathlib import Path # Keep Path for type hinting if used elsewhere, but not for output_dir here
logger = logging.getLogger(__name__)
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2-multimer"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
# _split_pdb_content remains the same
def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
"""
Splits a string containing concatenated PDB file contents.
@ -23,7 +24,7 @@ def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
for line in concatenated_pdb_str.splitlines(keepends=True):
current_pdb_lines.append(line)
if line.startswith("ENDMDL") or line.startswith("END "):
if line.startswith("ENDMDL") or line.startswith("END "): # Fixed: added space for "END "
pdbs.append("".join(current_pdb_lines).strip())
current_pdb_lines = []
@ -35,17 +36,8 @@ def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
return [pdb for pdb in pdbs if pdb]
# calculate_plddt_from_pdb_string remains the same
def calculate_plddt_from_pdb_string(pdb_string: str) -> Tuple[float, List[float], Dict[str, List[float]]]:
"""
Calculates the average pLDDT score from a PDB string for C-alpha atoms.
Also returns a list of all C-alpha pLDDTs and a dictionary of pLDDTs per chain.
Returns:
A tuple containing:
- average_plddt (float): Average pLDDT over all C-alpha atoms.
- plddt_scores_per_ca (List[float]): List of pLDDTs for each C-alpha atom.
- plddt_scores_per_chain (Dict[str, List[float]]): Dict mapping chain ID to its C-alpha pLDDTs.
"""
total_plddt = 0.0
ca_atom_count = 0
plddt_scores_per_ca: List[float] = []
@ -78,46 +70,25 @@ def calculate_plddt_from_pdb_string(pdb_string: str) -> Tuple[float, List[float]
return average_plddt, plddt_scores_per_ca, plddt_scores_per_chain
async def _process_pdb_and_scores_from_api(
pdb_contents: List[str],
pdb_contents: List[str], # This is the list of PDB strings from the API
job_id: str,
api_response_json: Optional[Dict[str, Any]] = None,
output_dir: Optional[Path] = None
) -> Optional[Dict[str, Any]]:
api_response_json: Optional[Dict[str, Any]] = None # Not currently used, but kept for consistency
) -> Optional[Dict[str, Any]]: # Return structure: {"structures": [...]}
"""
Processes a list of PDB strings received from the API JSON response.
- The API responds with a direct list of PDB strings, not a nested JSON structure
- This function saves each PDB to disk and calculates pLDDT scores
- The api_response_json parameter is for potential future use if the API adds metadata
- The output_dir parameter allows for customizing where files are saved
Processes a list of PDB strings received from the API.
- Calculates pLDDT scores for each PDB string.
- Does NOT save files to disk.
- Returns a dictionary containing a list of structures, each with its PDB content and scores.
"""
if output_dir is None:
# Default behavior if no output dir is provided
output_dir_name = f"alphafold2_multimer_output_{job_id}_results"
output_dir = Path(f"./{output_dir_name}")
# Make sure the directory exists
output_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Saving AlphaFold2-Multimer results for job {job_id} to directory: {output_dir}")
results: Dict[str, Any] = {
"output_directory": str(output_dir),
"structures": [],
"ptm_score": None,
"iptm_score": None,
"structures": []
}
if isinstance(api_response_json, dict):
logger.info(f"Attempting to extract additional scores from API JSON response for job {job_id}.")
results["ptm_score"] = api_response_json.get("ptm")
results["iptm_score"] = api_response_json.get("iptm")
if results["ptm_score"] is not None or results["iptm_score"] is not None:
logger.info(f"Extracted from API JSON: pTM={results['ptm_score']}, ipTM={results['iptm_score']}")
else:
logger.info(f"No additional API JSON dictionary provided or it's not a dict for job {job_id}; ptm/iptm scores will be None unless found elsewhere.")
if not pdb_contents or not isinstance(pdb_contents, list) or not all(isinstance(s, str) for s in pdb_contents):
logger.warning(f"No valid PDB content strings provided for job {job_id}.")
return results # Return with empty structures list
# Return a structure indicating failure or empty results, consistent with how call_alphafold2_multimer handles errors
return {"success": False, "error": "No valid PDB content strings from API.", "structures": []}
logger.info(f"Processing {len(pdb_contents)} PDB structure(s) for job {job_id}.")
@ -127,8 +98,8 @@ async def _process_pdb_and_scores_from_api(
continue
structure_data: Dict[str, Any] = {
"model_index": i,
"pdb_content": pdb_str
"model_index": i, # Keep model_index for identification
"pdb_content": pdb_str # Store the actual PDB content
}
avg_plddt, plddts_per_ca_residue, plddts_by_chain = calculate_plddt_from_pdb_string(pdb_str)
@ -144,26 +115,21 @@ async def _process_pdb_and_scores_from_api(
else:
avg_plddt_per_chain[chain_id] = 0.0
structure_data["average_plddt_per_chain"] = avg_plddt_per_chain
pdb_file_name_stem = f"alphafold2_multimer_output_{job_id}"
rank_suffix = f"_model_{i+1}"
pdb_file_path = output_dir / f"{pdb_file_name_stem}{rank_suffix}.pdb"
try:
with open(pdb_file_path, "w", encoding='utf-8') as f_pdb:
f_pdb.write(pdb_str)
structure_data["saved_pdb_path"] = str(pdb_file_path)
logger.info(f"Saved PDB model {i+1} for job {job_id} to {pdb_file_path} with overall avg_pLDDT: {avg_plddt:.2f}")
except Exception as e_write:
logger.error(f"Failed to write PDB file {pdb_file_path} for job {job_id}: {e_write}")
structure_data["saved_pdb_path"] = None
# REMOVED FILE SAVING LOGIC HERE
# structure_data["saved_pdb_path"] = ... (NO LONGER SET HERE)
results["structures"].append(structure_data)
if results["structures"]:
logger.info(f"Successfully processed and calculated pLDDTs for {len(results['structures'])} structures for job {job_id}.")
else:
logger.warning(f"No structures were processed for job {job_id}.")
# Ensure a consistent return structure even if no structures processed
return {"success": True, "message": "No PDB structures found in API response to process.", "structures": []}
return results
return results # This will contain {"structures": [...list of dicts with pdb_content and scores...]}
async def call_alphafold2_multimer(
sequences: List[str],
@ -177,13 +143,15 @@ async def call_alphafold2_multimer(
url: str = DEFAULT_URL,
status_url: str = DEFAULT_STATUS_URL,
polling_interval: int = 30,
timeout: int = 3600,
output_dir: Optional[Path] = None
timeout: int = 3600
# REMOVED output_dir: Optional[Path] = None
) -> Optional[Dict[str, Any]]:
"""
Call the NVIDIA NIM AlphaFold2-Multimer API.
The API now returns JSON with a list of PDB strings, which we process to calculate pLDDT scores.
Returns a dictionary with processed PDB strings and computed scores.
The API returns JSON with a list of PDB strings.
This function processes them to calculate pLDDT scores and returns a dictionary
containing a list of structures, each with its PDB content and computed scores.
File saving is handled by the caller (ToolExecutor).
"""
headers = {
"content-type": "application/json",
@ -213,28 +181,33 @@ async def call_alphafold2_multimer(
) as response:
if response.status == 200:
logger.info("AlphaFold2-Multimer job completed synchronously.")
logger.info(f"SYNC Final response headers: {response.headers}")
content_type = response.headers.get("Content-Type", "").lower()
logger.info(f"SYNC Final response content-type: {content_type}")
if "application/json" in content_type:
api_response_json_payload = await response.json()
if not isinstance(api_response_json_payload, list):
# Handle case where API might return {"error": ...} directly in a sync 200
if isinstance(api_response_json_payload, dict) and "error" in api_response_json_payload:
logger.error(f"Sync API call returned error: {api_response_json_payload['error']}")
return {"success": False, "error": api_response_json_payload['error'], "detail": api_response_json_payload.get("detail","")}
return {"success": False, "error": "Sync JSON response not a list of PDBs as expected."}
req_id_sync = response.headers.get("nvcf-reqid", "sync_job") # Get req_id or make one up
return await _process_pdb_and_scores_from_api(
req_id_sync = response.headers.get("nvcf-reqid", "sync_job")
return await _process_pdb_and_scores_from_api( # No output_dir
pdb_contents=api_response_json_payload,
job_id=req_id_sync,
api_response_json=None,
output_dir=output_dir
api_response_json=None # Not used currently
)
else:
return {"success": False, "error": f"Sync response unexpected content type: {content_type}"}
err_text = await response.text()
logger.error(f"Sync response unexpected content type: {content_type}. Response: {err_text[:500]}")
return {"success": False, "error": f"Sync response unexpected content type: {content_type}", "detail": err_text}
elif response.status == 202:
req_id = response.headers.get("nvcf-reqid")
if req_id:
logger.info(f"AlphaFold2-Multimer job submitted, request ID: {req_id}")
return await _poll_job_status(
return await _poll_job_status( # No output_dir
req_id=req_id,
headers=headers,
status_url=status_url,
@ -262,12 +235,14 @@ async def _poll_job_status(
status_url: str,
polling_interval: int = 30,
overall_timeout: int = 3600
# REMOVED output_dir: Optional[Path] = None
) -> Optional[Dict[str, Any]]:
start_time = asyncio.get_event_loop().time()
per_status_request_timeout = 600
logger.info(f"Polling job {req_id}. Individual status check timeout: {per_status_request_timeout}s, Polling interval: {polling_interval}s, Overall timeout: {overall_timeout}s")
while True:
# ... (polling logic for time checks remains the same) ...
current_loop_time = asyncio.get_event_loop().time()
elapsed_time = current_loop_time - start_time
@ -290,37 +265,33 @@ async def _poll_job_status(
headers=headers,
timeout=current_status_check_timeout
) as response:
if response.status == 200:
if response.status == 200: # Job completed
logger.info(f"AlphaFold2-Multimer job {req_id} completed (status 200).")
logger.info(f"FINAL 200 OK Response Headers for job {req_id}: {response.headers}")
logger.info(f"FINAL 200 OK Content-Type for job {req_id}: {response.content_type}")
if response.content_type == 'application/json':
try:
api_response_json_payload = await response.json()
logger.debug(f"API JSON Response for job {req_id}: {str(api_response_json_payload)[:500]}...")
if not isinstance(api_response_json_payload, list):
# Handle case where API might return {"error": ...} directly
if isinstance(api_response_json_payload, dict) and "error" in api_response_json_payload:
logger.error(f"Job {req_id}: API returned error: {api_response_json_payload['error']}")
return {"success": False, "error": api_response_json_payload['error'], "detail": api_response_json_payload.get("detail","")}
logger.error(f"Job {req_id}: Expected API response to be a list of PDB strings, got {type(api_response_json_payload)}.")
return {"success": False, "error": "API response was not a list of PDB strings."}
return await _process_pdb_and_scores_from_api(
return await _process_pdb_and_scores_from_api( # No output_dir
pdb_contents=api_response_json_payload,
job_id=req_id,
api_response_json=None,
output_dir=output_dir
api_response_json=None # Not used currently
)
except json.JSONDecodeError:
logger.error(f"Job {req_id}: Failed to decode JSON response from API.", exc_info=True)
raw_text = await response.text()
logger.debug(f"Raw text response: {raw_text[:500]}")
return {"success": False, "error": "Failed to decode JSON response."}
return {"success": False, "error": "Failed to decode JSON response.", "detail": raw_text[:500]}
else:
logger.error(f"Job {req_id}: Unexpected content type {response.content_type}. Expected application/json.")
raw_text = await response.text()
logger.debug(f"Raw text response: {raw_text[:500]}")
return {"success": False, "error": f"Unexpected content type: {response.content_type}"}
logger.error(f"Job {req_id}: Unexpected content type {response.content_type}. Expected application/json. Response: {raw_text[:500]}")
return {"success": False, "error": f"Unexpected content type: {response.content_type}", "detail": raw_text}
# ... (rest of polling logic: 202, errors, timeouts remains the same) ...
elif response.status == 202:
try:
job_status_json = await response.json()
@ -340,10 +311,10 @@ async def _poll_job_status(
return {"success": False, "error": f"Status check failed with HTTP {response.status}", "detail": text}
except asyncio.TimeoutError:
logger.warning(f"Client-side timeout ({current_status_check_timeout}s) during status check for job {req_id}. Retrying poll after {polling_interval}s sleep.")
await asyncio.sleep(polling_interval)
await asyncio.sleep(polling_interval) # Sleep before next attempt
except aiohttp.ClientError as e:
logger.error(f"Client error polling job status for {req_id}: {e}. Retrying poll after {polling_interval}s.", exc_info=True)
await asyncio.sleep(polling_interval)
except Exception as e:
logger.error(f"Unexpected error polling job status {req_id}: {e}", exc_info=True)
return {"success": False, "error": f"Unexpected polling error: {str(e)}"}
return {"success": False, "error": f"Unexpected polling error: {str(e)}"}