mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
366 lines
No EOL
17 KiB
Python
366 lines
No EOL
17 KiB
Python
"""AlphaFold2-Multimer API integration for NVIDIA NIM."""
|
|
|
|
import os
|
|
import logging
|
|
import aiohttp
|
|
import json
|
|
import asyncio
|
|
from typing import Dict, List, Any, Optional, Tuple
|
|
from pathlib import Path
|
|
import zipfile
|
|
import io
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Default URL
# Endpoint that call_alphafold2_multimer POSTs prediction requests to unless
# a different ``url`` argument is supplied.
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2-multimer"
# Base URL used when polling an asynchronous job; the request ID is appended
# as a path segment ("<status_url>/<req_id>").
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
|
|
|
|
# Helper functions
|
|
def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
|
|
"""
|
|
Splits a string containing concatenated PDB file contents.
|
|
Assumes models are separated by "ENDMDL" or just "END" for the last/single model.
|
|
"""
|
|
pdbs = []
|
|
current_pdb_lines = []
|
|
if not concatenated_pdb_str:
|
|
return []
|
|
|
|
for line in concatenated_pdb_str.splitlines(keepends=True):
|
|
current_pdb_lines.append(line)
|
|
if line.startswith("ENDMDL") or line.startswith("END "):
|
|
pdbs.append("".join(current_pdb_lines).strip())
|
|
current_pdb_lines = []
|
|
|
|
if current_pdb_lines:
|
|
remaining_pdb = "".join(current_pdb_lines).strip()
|
|
if remaining_pdb:
|
|
pdbs.append(remaining_pdb)
|
|
|
|
return [pdb for pdb in pdbs if pdb]
|
|
|
|
|
|
def calculate_plddt_from_pdb_string(pdb_string: str) -> Tuple[float, List[float], Dict[str, List[float]]]:
    """
    Collect the pLDDT values of all C-alpha atoms found in a PDB string.

    Only lines starting with "ATOM" whose atom-name field (characters 13-16)
    is "CA" are considered. The score is parsed from characters 61-66 and the
    chain ID is taken from character 22. Lines whose score field cannot be
    parsed as a float are skipped.

    Returns:
        A tuple ``(average_plddt, plddt_scores_per_ca, plddt_scores_per_chain)``:
        - average_plddt: mean score over all accepted C-alpha atoms.
        - plddt_scores_per_ca: the accepted scores in file order.
        - plddt_scores_per_chain: chain ID -> that chain's scores in file order.
        ``(0.0, [], {})`` is returned when no C-alpha score was found.
    """
    ca_scores: List[float] = []
    scores_by_chain: Dict[str, List[float]] = {}
    running_total = 0.0

    for record in pdb_string.splitlines():
        if not record.startswith("ATOM"):
            continue
        if record[12:16].strip() != "CA":
            continue

        try:
            score = float(record[60:66].strip())
        except (ValueError, IndexError):
            # Truncated or non-numeric score field: ignore this atom.
            continue

        running_total += score
        ca_scores.append(score)
        scores_by_chain.setdefault(record[21:22].strip(), []).append(score)

    if not ca_scores:
        return 0.0, [], {}

    return running_total / len(ca_scores), ca_scores, scores_by_chain
|
|
|
|
async def _process_nvidia_zip_output(
    zip_content: bytes,
    output_prefix: str,
    response_headers: Optional[Dict[str, str]] = None
) -> Optional[Dict[str, Any]]:
    """
    Save a result ZIP, extract its PDB structures and score each of them.

    The archive is written to ``./<output_prefix>_results/<output_prefix>.zip``.
    If it contains a member ending in ".response", that member is saved next
    to the ZIP and split into structures with _split_pdb_content; otherwise
    every member ending in ".pdb" is taken as one structure. Each non-empty
    structure is scored with calculate_plddt_from_pdb_string and written to
    ``<stem of output_prefix>_model_<index>.pdb`` in the same directory.

    Args:
        zip_content: Raw bytes of the ZIP archive.
        output_prefix: Used to name the output directory and the files in it.
        response_headers: Accepted for the callers' convenience; not used here.

    Returns:
        A dict with keys "zip_file_path", "structures", "iptm_score",
        "ptm_score" (both always None) and "extracted_response_file_path".
        Errors while reading the archive are logged and whatever was
        collected so far is returned.
    """
    output_dir = Path(f"./{output_prefix}_results")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Keep a copy of the archive exactly as received.
    zip_file_path = output_dir / f"{output_prefix}.zip"
    zip_file_path.write_bytes(zip_content)
    logger.info(f"Downloaded and saved original ZIP file to {zip_file_path}")

    results: Dict[str, Any] = {
        "zip_file_path": str(zip_file_path),
        "structures": [],
        # iptm/ptm are not parsed from the archive; the keys are kept as None.
        "iptm_score": None,
        "ptm_score": None,
        "extracted_response_file_path": None,
    }

    pdb_texts: List[str] = []

    try:
        with zipfile.ZipFile(io.BytesIO(zip_content)) as archive:
            members = archive.namelist()
            # A ".response" member (concatenated PDBs) takes precedence.
            response_member = next(
                (name for name in members if name.lower().endswith(".response")),
                None,
            )

            if response_member:
                logger.info(f"Found concatenated response file in ZIP: {response_member}")
                response_text = archive.read(response_member).decode('utf-8', errors='replace')

                # Persist the raw member under its base name only.
                response_copy_path = output_dir / Path(response_member).name
                response_copy_path.write_text(response_text, encoding='utf-8', errors='replace')
                results["extracted_response_file_path"] = str(response_copy_path)
                logger.info(f"Saved raw content of '{response_member}' to {response_copy_path}")

                pdb_texts.extend(_split_pdb_content(response_text))
            else:
                logger.info("No .response file found. Looking for individual .pdb files in ZIP.")
                for name in members:
                    if not name.lower().endswith(".pdb"):
                        continue
                    logger.info(f"Found individual PDB file in ZIP: {name}")
                    pdb_texts.append(archive.read(name).decode('utf-8', errors='replace'))

        if not pdb_texts:
            logger.warning(f"No PDB content found in ZIP archive {zip_file_path} (either as .response or individual .pdb files).")
            return results

        logger.info(f"Found {len(pdb_texts)} PDB structure(s) to process.")

        file_stem = Path(output_prefix).stem
        for index, pdb_text in enumerate(pdb_texts):
            if not pdb_text.strip():
                logger.debug(f"Skipping empty PDB string at index {index}.")
                continue

            mean_plddt, ca_plddts, chain_plddts = calculate_plddt_from_pdb_string(pdb_text)
            chain_means = {
                chain: (sum(values) / len(values) if values else 0.0)
                for chain, values in chain_plddts.items()
            }

            structure: Dict[str, Any] = {
                "model_index": index,  # position in the order the structures were found
                "pdb_content": pdb_text,
                "average_plddt": mean_plddt,
                "plddt_scores_per_ca_residue": ca_plddts,
                "plddt_scores_per_chain": chain_plddts,
                "average_plddt_per_chain": chain_means,
            }

            model_path = output_dir / f"{file_stem}_model_{index}.pdb"
            try:
                model_path.write_text(pdb_text, encoding='utf-8')
                structure["saved_pdb_path"] = str(model_path)
                logger.info(f"Saved PDB model {index} to {model_path} with overall avg_pLDDT: {mean_plddt:.2f}")
            except Exception as write_error:
                # A failed write does not discard the in-memory scores.
                logger.error(f"Failed to write PDB file {model_path}: {write_error}")
                structure["saved_pdb_path"] = None

            results["structures"].append(structure)

        if results["structures"]:
            logger.info(f"Successfully processed and calculated pLDDTs for {len(results['structures'])} structures.")

    except zipfile.BadZipFile:
        logger.error(f"Failed to process ZIP: {zip_file_path} is not a valid ZIP file.")
    except Exception as e:
        logger.error(f"An error occurred during ZIP processing of {zip_file_path}: {e}", exc_info=True)

    return results
|
|
|
|
async def call_alphafold2_multimer(
    sequences: List[str],
    api_key: str,
    algorithm: str = "jackhmmer",
    e_value: float = 0.0001,
    iterations: int = 1,
    databases: List[str] = ["uniref90", "small_bfd", "mgnify"],
    relax_prediction: bool = True,
    selected_models: Optional[List[int]] = None,
    url: str = DEFAULT_URL,
    status_url: str = DEFAULT_STATUS_URL,
    polling_interval: int = 30,
    timeout: int = 3600
) -> Optional[Dict[str, Any]]:
    """
    Call the NVIDIA NIM AlphaFold2-Multimer API.

    The request is POSTed to ``url``. A 200 response is treated as the
    finished result; a 202 response is followed by polling ``status_url``
    with the request ID from the "nvcf-reqid" header. The POST connection is
    released before polling or ZIP processing starts, so it is not held open
    for the lifetime of the job.

    Args:
        sequences: Sequences sent as the "sequences" field of the request.
        api_key: Bearer token placed in the Authorization header.
        algorithm, e_value, iterations, databases, relax_prediction:
            Forwarded unchanged in the request body.
        selected_models: Sent as "selected_models" only when not None.
        url: Endpoint for the initial POST.
        status_url: Base URL for status polling.
        polling_interval: Seconds to wait between status polls.
        timeout: Overall polling budget in seconds; the initial POST is
            limited to ``min(timeout, 600)`` seconds.

    Returns:
        The dictionary produced by _process_nvidia_zip_output, or None on any
        error (non-200/202 status, missing request ID, timeout, exception).
        Errors are logged, not raised.
    """
    headers = {
        "content-type": "application/json",
        "Authorization": f"Bearer {api_key}",
        "NVCF-POLL-SECONDS": "300",
    }
    data: Dict[str, Any] = {
        "sequences": sequences,
        "algorithm": algorithm,
        "e_value": e_value,
        "iterations": iterations,
        # Copy so the payload never aliases the shared default list object.
        "databases": list(databases),
        "relax_prediction": relax_prediction
    }
    if selected_models is not None:
        data["selected_models"] = selected_models
        logger.info(f"Using selected_models: {selected_models}")

    try:
        initial_post_timeout = min(timeout, 600)
        content: Optional[bytes] = None
        req_id: Optional[str] = None
        response_headers = None

        async with aiohttp.ClientSession() as session:
            async with session.post(
                url,
                json=data,
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=initial_post_timeout)
            ) as response:
                status = response.status
                response_headers = response.headers
                if status == 200:
                    logger.info("AlphaFold2-Multimer job completed synchronously.")
                    content = await response.read()
                elif status == 202:
                    req_id = response.headers.get("nvcf-reqid")
                else:  # Handle other error statuses from POST
                    logger.error(f"Error calling AlphaFold2-Multimer API (POST): {status}")
                    text = await response.text()
                    logger.error(f"Response: {text}")
                    return None

        # The POST response and session are closed from here on; only the
        # values captured above are used.
        if status == 200:
            return await _process_nvidia_zip_output(
                zip_content=content,
                output_prefix="alphafold2_multimer_sync_output",
                response_headers=response_headers
            )

        if not req_id:
            logger.error("No request ID in 202 response headers")
            return None

        logger.info(f"AlphaFold2-Multimer job submitted, request ID: {req_id}")
        return await _poll_job_status(
            req_id=req_id,
            headers=headers,
            status_url=status_url,
            polling_interval=polling_interval,
            overall_timeout=timeout
        )
    except asyncio.TimeoutError:
        logger.error("Timeout during AlphaFold2-Multimer API (initial POST).")
        return None
    except Exception as e:
        logger.error(f"Exception during AlphaFold2-Multimer API call (initial POST): {e}", exc_info=True)
        return None
|
|
|
|
async def _poll_job_status(
    req_id: str,
    headers: Dict[str, str],
    status_url: str,
    polling_interval: int = 30,
    overall_timeout: int = 3600
) -> Optional[Dict[str, Any]]:
    """
    Poll ``<status_url>/<req_id>`` until the job finishes or time runs out.

    Each status request is given at most 330 seconds (slightly more than the
    300 s sent in the NVCF-POLL-SECONDS header), capped by whatever is left
    of ``overall_timeout``. A 200 response is treated as the finished result
    and handed to _process_nvidia_zip_output. A 202 response, a client-side
    timeout, or an aiohttp client error leads to another poll after
    ``polling_interval`` seconds. Any other status ends polling.

    Args:
        req_id: Request ID of the submitted job.
        headers: HTTP headers sent with every status request.
        status_url: Base URL of the status endpoint.
        polling_interval: Seconds to sleep between polls.
        overall_timeout: Total seconds allowed for polling.

    Returns:
        The dictionary from _process_nvidia_zip_output, or None if the overall
        timeout is exceeded, the server answers with a status other than
        200/202, or an unexpected exception occurs. Errors are logged, not
        raised.
    """
    loop = asyncio.get_running_loop()
    start_time = loop.time()
    # Slightly longer than the 300 s long-poll window requested via NVCF-POLL-SECONDS.
    status_check_timeout = 330  # seconds (e.g., 5.5 minutes)
    logger.info(f"Polling job {req_id}. Status check timeout: {status_check_timeout}s, Polling interval: {polling_interval}s, Overall timeout: {overall_timeout}s")

    while True:
        elapsed_time = loop.time() - start_time

        if elapsed_time > overall_timeout:
            logger.error(f"Overall polling timeout of {overall_timeout}s exceeded for job {req_id}.")
            return None

        remaining_time_for_overall_timeout = overall_timeout - elapsed_time
        current_status_check_timeout = min(status_check_timeout, remaining_time_for_overall_timeout)
        if current_status_check_timeout <= 0:
            logger.error(f"Not enough time left for another status check for job {req_id} within overall_timeout.")
            return None

        zip_content_bytes: Optional[bytes] = None
        response_headers = None
        try:
            async with aiohttp.ClientSession() as session:
                logger.debug(f"Checking status for {req_id} with timeout {current_status_check_timeout}s.")
                async with session.get(
                    f"{status_url}/{req_id}",
                    headers=headers,
                    timeout=aiohttp.ClientTimeout(total=current_status_check_timeout)
                ) as response:
                    if response.status == 200:
                        logger.info(f"AlphaFold2-Multimer job {req_id} completed (status 200).")
                        logger.info(f"FINAL 200 OK Response Headers for job {req_id}: {response.headers}")
                        logger.info(f"FINAL 200 OK Content-Type for job {req_id}: {response.content_type}")
                        zip_content_bytes = await response.read()
                        response_headers = response.headers
                    elif response.status == 202:
                        try:
                            job_status_json = await response.json()
                        except (aiohttp.ContentTypeError, ValueError):
                            job_status_json = None
                        # Only a JSON object carries progress fields; any other
                        # body just means "still running".
                        if isinstance(job_status_json, dict):
                            percent_complete = job_status_json.get('percentComplete', 'N/A')
                            status_message = job_status_json.get('status', 'running')
                            logger.debug(
                                f"Job {req_id} status: {status_message} ({percent_complete}% complete). Polling again in {polling_interval}s."
                            )
                        else:
                            logger.debug(
                                f"Job {req_id} still running (202 status, non-JSON/malformed JSON body). Polling again in {polling_interval}s."
                            )
                    else:  # Handle other error statuses from status GET
                        logger.error(f"Error checking AlphaFold2-Multimer job status {req_id}: {response.status}")
                        text = await response.text()
                        logger.error(f"Response: {text}")
                        # Any status other than 200/202 ends polling; callers see None.
                        return None

            # The connection is released before the (potentially slow) ZIP
            # processing; still inside the try so failures map to None as before.
            if zip_content_bytes is not None:
                return await _process_nvidia_zip_output(
                    zip_content=zip_content_bytes,
                    output_prefix=f"alphafold2_multimer_output_{req_id}",
                    response_headers=response_headers
                )
        except asyncio.TimeoutError:
            logger.warning(f"Client-side timeout ({current_status_check_timeout}s) during status check for job {req_id}. Retrying poll after {polling_interval}s sleep.")
        except aiohttp.ClientError as e:
            logger.error(f"Client error polling job status for {req_id}: {e}. Retrying poll after {polling_interval}s.", exc_info=True)
        except Exception as e:
            logger.error(f"Unexpected error polling job status {req_id}: {e}", exc_info=True)
            return None

        # Reached after a 202, a client-side timeout, or a client error.
        # No response or session is held open while waiting.
        await asyncio.sleep(polling_interval)