mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
427 lines
18 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import aiohttp
|
|
|
|
# Module-level logger shared by every function in this file.
logger = logging.getLogger(name=__name__)

# Common root of the NVIDIA health API endpoints used below.
_NVIDIA_HEALTH_API_V1 = "https://health.api.nvidia.com/v1"

# Endpoint that receives AlphaFold2-Multimer prediction requests (POST).
DEFAULT_URL = f"{_NVIDIA_HEALTH_API_V1}/biology/deepmind/alphafold2-multimer"

# Base status endpoint; asynchronous jobs are polled at "<status_url>/<request id>".
DEFAULT_STATUS_URL = f"{_NVIDIA_HEALTH_API_V1}/status"
|
|
|
|
|
|
def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
|
|
"""
|
|
Splits a string containing concatenated PDB file contents.
|
|
Assumes models are separated by "ENDMDL" or just "END" for the last/single model.
|
|
"""
|
|
pdbs = []
|
|
current_pdb_lines = []
|
|
if not concatenated_pdb_str:
|
|
return []
|
|
|
|
for line in concatenated_pdb_str.splitlines(keepends=True):
|
|
current_pdb_lines.append(line)
|
|
if line.startswith("ENDMDL") or line.startswith("END "):
|
|
pdbs.append("".join(current_pdb_lines).strip())
|
|
current_pdb_lines = []
|
|
|
|
if current_pdb_lines:
|
|
remaining_pdb = "".join(current_pdb_lines).strip()
|
|
if remaining_pdb:
|
|
pdbs.append(remaining_pdb)
|
|
|
|
return [pdb for pdb in pdbs if pdb]
|
|
|
|
|
|
def calculate_plddt_from_pdb_string(
    pdb_string: str,
) -> Tuple[float, List[float], Dict[str, List[float]]]:
    """
    Extract pLDDT scores from the C-alpha atoms of a PDB string.

    For every ``ATOM`` record whose atom name (columns 13-16) is ``CA``, the
    value in the B-factor field (columns 61-66) is read as that residue's
    pLDDT. Records whose field is missing or not a valid float are skipped.
    ``HETATM`` records are never considered, which also keeps calcium ions
    (atom name ``CA``) out of the scores.

    Args:
        pdb_string: Text of a single PDB model.

    Returns:
        ``(average, per_ca, per_chain)``:
        - ``average``: mean over all parsed C-alpha atoms, ``0.0`` if none.
        - ``per_ca``: the values in file order.
        - ``per_chain``: maps the chain ID (column 22, stripped, so a blank
          chain is ``""``) to that chain's values in file order.
    """
    plddt_scores_per_ca: List[float] = []
    plddt_scores_per_chain: Dict[str, List[float]] = {}

    for line in pdb_string.splitlines():
        if not line.startswith("ATOM") or line[12:16].strip() != "CA":
            continue
        try:
            plddt_value = float(line[60:66].strip())
        except ValueError:
            # Best-effort parsing: a short line slices to "" and a garbled
            # field fails float(); either way the atom is skipped. (Slicing
            # never raises IndexError, so no handler is needed for that.)
            continue
        plddt_scores_per_ca.append(plddt_value)
        chain_id = line[21:22].strip()
        plddt_scores_per_chain.setdefault(chain_id, []).append(plddt_value)

    if not plddt_scores_per_ca:
        return 0.0, [], {}

    average_plddt = sum(plddt_scores_per_ca) / len(plddt_scores_per_ca)
    return average_plddt, plddt_scores_per_ca, plddt_scores_per_chain
|
|
|
|
|
|
async def _process_pdb_and_scores_from_api(
    pdb_contents: List[str],
    job_id: str,
    api_response_json: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
    """
    Score a list of PDB strings received from the AlphaFold2-Multimer API.

    For each non-blank PDB string, pLDDT scores are computed with
    ``calculate_plddt_from_pdb_string`` together with per-chain averages.
    Nothing is written to disk.

    Args:
        pdb_contents: PDB file contents from the API. Must be a non-empty
            ``list`` of ``str``.
        job_id: Identifier used only in log messages.
        api_response_json: Unused; kept for interface compatibility.

    Returns:
        A dict with a ``"success"`` flag and a ``"structures"`` list.
        - Each structure dict holds ``model_index`` (index into
          ``pdb_contents``), ``pdb_content``, ``average_plddt``,
          ``plddt_scores_per_ca_residue``, ``plddt_scores_per_chain`` and
          ``average_plddt_per_chain``.
        - Invalid input gives ``success=False`` plus an ``"error"``.
        - Input containing only blank strings gives ``success=True`` plus a
          ``"message"`` and no structures.
    """
    if (
        not pdb_contents
        or not isinstance(pdb_contents, list)
        or not all(isinstance(s, str) for s in pdb_contents)
    ):
        logger.warning(f"No valid PDB content strings provided for job {job_id}.")
        return {
            "success": False,
            "error": "No valid PDB content strings from API.",
            "structures": [],
        }

    logger.info(f"Processing {len(pdb_contents)} PDB structure(s) for job {job_id}.")

    structures: List[Dict[str, Any]] = []
    for i, pdb_str in enumerate(pdb_contents):
        if not pdb_str.strip():
            logger.debug(f"Skipping empty PDB string at index {i} for job {job_id}.")
            continue

        avg_plddt, plddts_per_ca_residue, plddts_by_chain = (
            calculate_plddt_from_pdb_string(pdb_str)
        )
        structures.append(
            {
                # model_index keeps the position in the API list, so skipped
                # blank entries leave gaps rather than renumbering models.
                "model_index": i,
                "pdb_content": pdb_str,
                "average_plddt": avg_plddt,
                "plddt_scores_per_ca_residue": plddts_per_ca_residue,
                "plddt_scores_per_chain": plddts_by_chain,
                "average_plddt_per_chain": {
                    chain_id: (
                        sum(chain_plddts) / len(chain_plddts) if chain_plddts else 0.0
                    )
                    for chain_id, chain_plddts in plddts_by_chain.items()
                },
            }
        )

    if not structures:
        logger.warning(f"No structures were processed for job {job_id}.")
        return {
            "success": True,
            "message": "No PDB structures found in API response to process.",
            "structures": [],
        }

    logger.info(
        f"Successfully processed and calculated pLDDTs for "
        f"{len(structures)} structures for job {job_id}."
    )
    # Every other return path carries "success"; the success path previously
    # omitted it, so result.get("success") was None on success.
    return {"success": True, "structures": structures}
|
|
|
|
|
|
async def call_alphafold2_multimer(
    sequences: List[str],
    api_key: str,
    algorithm: str = "jackhmmer",
    e_value: float = 0.0001,
    iterations: int = 1,
    databases: Optional[List[str]] = None,
    relax_prediction: bool = True,
    selected_models: Optional[List[int]] = None,
    url: str = DEFAULT_URL,
    status_url: str = DEFAULT_STATUS_URL,
    polling_interval: int = 30,
    timeout: int = 3600,
) -> Optional[Dict[str, Any]]:
    """
    Call the NVIDIA NIM AlphaFold2-Multimer API and score the returned models.

    The request is POSTed to ``url``, and the response status decides what
    happens next:
    - HTTP 200: handled synchronously.
    - HTTP 202: the response carries an ``nvcf-reqid`` header, and the job is
      then polled at ``status_url`` via ``_poll_job_status``.

    In both cases the API is expected to return JSON containing a list of PDB
    strings. These are scored (pLDDT) by ``_process_pdb_and_scores_from_api``.
    File saving is handled by the caller (ToolExecutor).

    Args:
        sequences: Input sequences, forwarded as ``"sequences"``.
        api_key: Bearer token sent in the ``Authorization`` header.
        algorithm, e_value, iterations, relax_prediction: Forwarded unchanged
            in the request body.
        databases: Forwarded as ``"databases"``. Defaults to
            ``["uniref90", "small_bfd", "mgnify"]``.
        selected_models: Forwarded as ``"selected_models"`` when not None.
        url: Prediction endpoint.
        status_url: Base status URL, polled as ``{status_url}/{req_id}``.
        polling_interval: Seconds between status checks.
        timeout: Overall budget in seconds. The initial POST is capped at
            ``min(timeout, 600)``; polling gets whatever remains.

    Returns:
        The dict from ``_process_pdb_and_scores_from_api`` on success.
        Otherwise a dict with ``"success": False``, an ``"error"`` message,
        and sometimes a ``"detail"``.
    """
    # None sentinel instead of a shared mutable list default.
    if databases is None:
        databases = ["uniref90", "small_bfd", "mgnify"]

    headers = {
        "content-type": "application/json",
        "Authorization": f"Bearer {api_key}",
        "NVCF-POLL-SECONDS": "300",
    }
    data: Dict[str, Any] = {
        "sequences": sequences,
        "algorithm": algorithm,
        "e_value": e_value,
        "iterations": iterations,
        "databases": databases,
        "relax_prediction": relax_prediction,
    }
    if selected_models is not None:
        data["selected_models"] = selected_models
        logger.info(f"Using selected_models: {selected_models}")

    loop = asyncio.get_running_loop()
    start_time = loop.time()
    req_id: Optional[str] = None

    try:
        post_timeout = aiohttp.ClientTimeout(total=min(timeout, 600))
        async with aiohttp.ClientSession() as session:
            async with session.post(
                url, json=data, headers=headers, timeout=post_timeout
            ) as response:
                if response.status == 200:
                    logger.info("AlphaFold2-Multimer job completed synchronously.")
                    return await _handle_sync_alphafold_response(response)

                if response.status != 202:
                    logger.error(
                        f"Error calling AlphaFold2-Multimer API (POST): {response.status}"
                    )
                    text = await response.text()
                    logger.error(f"Response: {text}")
                    return {
                        "success": False,
                        "error": f"Error calling API: {response.status}",
                        "detail": text,
                    }

                req_id = response.headers.get("nvcf-reqid")
                if not req_id:
                    logger.error("No request ID in 202 response headers")
                    return {
                        "success": False,
                        "error": "No request ID in 202 response headers",
                    }
                logger.info(f"AlphaFold2-Multimer job submitted, request ID: {req_id}")
    except asyncio.TimeoutError:
        logger.error("Timeout during AlphaFold2-Multimer API (initial POST).")
        return {"success": False, "error": "Timeout during initial API request"}
    except Exception as e:
        logger.error(
            f"Exception during AlphaFold2-Multimer API call (initial POST): {e}",
            exc_info=True,
        )
        return {"success": False, "error": f"Exception during API call: {str(e)}"}

    # Poll only after the POST response and its session have been released,
    # so that connection is not held open for the lifetime of the job.
    # Time already spent in the POST counts against the overall budget.
    remaining_timeout = max(0, int(timeout - (loop.time() - start_time)))
    return await _poll_job_status(
        req_id=req_id,
        headers=headers,
        status_url=status_url,
        polling_interval=polling_interval,
        overall_timeout=remaining_timeout,
    )


async def _handle_sync_alphafold_response(
    response: aiohttp.ClientResponse,
) -> Optional[Dict[str, Any]]:
    """Parse and score a synchronous (HTTP 200) AlphaFold2-Multimer POST response."""
    content_type = response.headers.get("Content-Type", "").lower()
    if "application/json" not in content_type:
        err_text = await response.text()
        logger.error(
            f"Sync response unexpected content type: {content_type}. "
            f"Response: {err_text[:500]}"
        )
        return {
            "success": False,
            "error": f"Sync response unexpected content type: {content_type}",
            "detail": err_text,
        }

    api_response_json_payload = await response.json()
    if not isinstance(api_response_json_payload, list):
        if (
            isinstance(api_response_json_payload, dict)
            and "error" in api_response_json_payload
        ):
            logger.error(
                f"Sync API call returned error: "
                f"{api_response_json_payload['error']}"
            )
            return {
                "success": False,
                "error": api_response_json_payload["error"],
                "detail": api_response_json_payload.get("detail", ""),
            }
        return {
            "success": False,
            "error": "Sync JSON response not a list of PDBs as expected.",
        }

    req_id_sync = response.headers.get("nvcf-reqid", "sync_job")
    return await _process_pdb_and_scores_from_api(
        pdb_contents=api_response_json_payload,
        job_id=req_id_sync,
        api_response_json=None,
    )
|
|
|
|
|
|
async def _poll_job_status(
    req_id: str,
    headers: Dict[str, str],
    status_url: str,
    polling_interval: int = 30,
    overall_timeout: int = 3600,
) -> Optional[Dict[str, Any]]:
    """
    Poll ``{status_url}/{req_id}`` until the job completes, fails, or times out.

    Each iteration issues one GET, limited to
    ``min(600, time left in overall_timeout)`` seconds. The response status
    decides what happens next:
    - HTTP 200: the job is done. The JSON body, expected to be a list of PDB
      strings, is scored by ``_process_pdb_and_scores_from_api`` and that
      result is returned.
    - HTTP 202: the job is still running. Progress is logged at DEBUG level
      and the loop sleeps ``polling_interval`` seconds.
    - Any other status: returned as a failure.

    Client-side timeouts and ``aiohttp.ClientError`` are logged and retried
    after ``polling_interval`` seconds. Any other exception ends polling.

    Args:
        req_id: Request ID taken from the ``nvcf-reqid`` header.
        headers: Request headers (including authorization) reused for polling.
        status_url: Base status URL; ``req_id`` is appended as a path segment.
        polling_interval: Seconds to wait between status checks.
        overall_timeout: Total polling budget in seconds.

    Returns:
        The processed-structures dict on success. Otherwise a dict with
        ``"success": False``, an ``"error"`` message, and ``"detail"`` where
        available.
    """
    loop = asyncio.get_running_loop()
    start_time = loop.time()
    per_status_request_timeout = 600
    logger.info(
        f"Polling job {req_id}. Individual status check timeout: "
        f"{per_status_request_timeout}s, Polling interval: {polling_interval}s, "
        f"Overall timeout: {overall_timeout}s"
    )

    # One session for the whole loop, so status checks can reuse pooled
    # connections instead of building a new session/connector on every poll.
    async with aiohttp.ClientSession() as session:
        while True:
            elapsed_time = loop.time() - start_time
            if elapsed_time >= overall_timeout:
                logger.error(
                    f"Overall polling timeout of {overall_timeout}s exceeded for "
                    f"job {req_id}."
                )
                return {"success": False, "error": "Overall polling timeout exceeded."}

            # Always > 0 here, because elapsed_time < overall_timeout.
            current_status_check_timeout = min(
                per_status_request_timeout, overall_timeout - elapsed_time
            )

            try:
                logger.debug(
                    f"Checking status for {req_id} with timeout "
                    f"{current_status_check_timeout}s."
                )
                async with session.get(
                    f"{status_url}/{req_id}",
                    headers=headers,
                    timeout=aiohttp.ClientTimeout(total=current_status_check_timeout),
                ) as response:
                    if response.status == 200:
                        logger.info(
                            f"AlphaFold2-Multimer job {req_id} completed (status 200)."
                        )
                        return await _parse_completed_job_response(response, req_id)

                    if response.status != 202:
                        text = await response.text()
                        logger.error(
                            f"Error checking AlphaFold2-Multimer job status {req_id}: "
                            f"HTTP {response.status} - {text}"
                        )
                        return {
                            "success": False,
                            "error": f"Status check failed with HTTP {response.status}",
                            "detail": text,
                        }

                    await _log_job_progress(response, req_id, polling_interval)

                # Sleep after leaving the response context, so its connection
                # is returned to the pool first.
                await asyncio.sleep(polling_interval)
            except asyncio.TimeoutError:
                logger.warning(
                    f"Client-side timeout ({current_status_check_timeout}s) during status check for "
                    f"job {req_id}. Retrying poll after {polling_interval}s sleep."
                )
                await asyncio.sleep(polling_interval)
            except aiohttp.ClientError as e:
                logger.error(
                    f"Client error polling job status for {req_id}: {e}. "
                    f"Retrying poll after {polling_interval}s.",
                    exc_info=True,
                )
                await asyncio.sleep(polling_interval)
            except Exception as e:
                logger.error(
                    f"Unexpected error polling job status {req_id}: {e}", exc_info=True
                )
                return {"success": False, "error": f"Unexpected polling error: {str(e)}"}


async def _parse_completed_job_response(
    response: aiohttp.ClientResponse, req_id: str
) -> Optional[Dict[str, Any]]:
    """Turn a completed (HTTP 200) status response into the final result dict."""
    if response.content_type != "application/json":
        raw_text = await response.text()
        logger.error(
            f"Job {req_id}: Unexpected content type {response.content_type}. "
            f"Expected application/json. Response: {raw_text[:500]}"
        )
        return {
            "success": False,
            "error": f"Unexpected content type: {response.content_type}",
            "detail": raw_text,
        }

    try:
        api_response_json_payload = await response.json()
    except json.JSONDecodeError:
        logger.error(
            f"Job {req_id}: Failed to decode JSON response from API.",
            exc_info=True,
        )
        raw_text = await response.text()
        return {
            "success": False,
            "error": "Failed to decode JSON response.",
            "detail": raw_text[:500],
        }

    if not isinstance(api_response_json_payload, list):
        if (
            isinstance(api_response_json_payload, dict)
            and "error" in api_response_json_payload
        ):
            logger.error(
                f"Job {req_id}: API returned error: "
                f"{api_response_json_payload['error']}"
            )
            return {
                "success": False,
                "error": api_response_json_payload["error"],
                "detail": api_response_json_payload.get("detail", ""),
            }
        logger.error(
            f"Job {req_id}: Expected API response to be a list of PDB strings, "
            f"got {type(api_response_json_payload)}."
        )
        return {
            "success": False,
            "error": "API response was not a list of PDB strings.",
        }

    return await _process_pdb_and_scores_from_api(
        pdb_contents=api_response_json_payload,
        job_id=req_id,
        api_response_json=None,
    )


async def _log_job_progress(
    response: aiohttp.ClientResponse, req_id: str, polling_interval: int
) -> None:
    """Log whatever progress information a 202 (still running) response carries."""
    try:
        job_status_json = await response.json()
    except (aiohttp.ContentTypeError, json.JSONDecodeError):
        job_status_json = None

    # Only a JSON object has progress fields. response.json() returns None for
    # an empty body, and a list/scalar is also possible; calling .get() on
    # those used to raise AttributeError and abort polling.
    if isinstance(job_status_json, dict):
        percent_complete = job_status_json.get("percentComplete", "N/A")
        status_message = job_status_json.get("status", "running")
        logger.debug(
            f"Job {req_id} status: {status_message} ({percent_complete}% complete). "
            f"Polling again in {polling_interval}s."
        )
    else:
        logger.debug(
            f"Job {req_id} still running (202 status, non-JSON/malformed JSON body). "
            f"Polling again in {polling_interval}s."
        )
|