atropos/environments/hack0/protein_design_env/models/alphafold2_multimer.py
2025-05-20 20:12:59 -07:00

366 lines
No EOL
17 KiB
Python

"""AlphaFold2-Multimer API integration for NVIDIA NIM."""
import os
import logging
import aiohttp
import json
import asyncio
from typing import Dict, List, Any, Optional, Tuple
from pathlib import Path
import zipfile
import io
logger = logging.getLogger(__name__)
# Default endpoints for the hosted NVIDIA API.
# DEFAULT_URL receives the prediction POST sent by call_alphafold2_multimer.
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2-multimer"
# Base URL for asynchronous job status; _poll_job_status requests
# f"{status_url}/{req_id}" against it.
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
# Helper functions
def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
"""
Splits a string containing concatenated PDB file contents.
Assumes models are separated by "ENDMDL" or just "END" for the last/single model.
"""
pdbs = []
current_pdb_lines = []
if not concatenated_pdb_str:
return []
for line in concatenated_pdb_str.splitlines(keepends=True):
current_pdb_lines.append(line)
if line.startswith("ENDMDL") or line.startswith("END "):
pdbs.append("".join(current_pdb_lines).strip())
current_pdb_lines = []
if current_pdb_lines:
remaining_pdb = "".join(current_pdb_lines).strip()
if remaining_pdb:
pdbs.append(remaining_pdb)
return [pdb for pdb in pdbs if pdb]
def calculate_plddt_from_pdb_string(pdb_string: str) -> Tuple[float, List[float], Dict[str, List[float]]]:
    """
    Average the pLDDT values stored on C-alpha atoms of a PDB string.

    Only ``ATOM`` records whose atom name (columns 13-16) is ``CA`` are used.
    The value is read from the B-factor field (columns 61-66), and the chain
    ID from column 22. Records whose value field does not parse as a float
    are skipped.

    Returns:
        A tuple ``(average_plddt, plddt_scores_per_ca, plddt_scores_per_chain)``:
        - average_plddt: mean over every accepted C-alpha (0.0 if there are none).
        - plddt_scores_per_ca: the accepted values in file order.
        - plddt_scores_per_chain: chain ID -> that chain's values in file order.
        When no C-alpha value is found, ``(0.0, [], {})`` is returned.
    """
    running_sum = 0.0  # accumulated in file order
    ca_values: List[float] = []
    by_chain: Dict[str, List[float]] = {}
    for record in pdb_string.splitlines():
        if not record.startswith("ATOM") or record[12:16].strip() != "CA":
            continue
        try:
            value = float(record[60:66].strip())
        except ValueError:
            # Missing or non-numeric B-factor column: ignore this atom.
            continue
        running_sum += value
        ca_values.append(value)
        by_chain.setdefault(record[21:22].strip(), []).append(value)
    if not ca_values:
        return 0.0, [], {}
    return running_sum / len(ca_values), ca_values, by_chain
def _pdb_strings_from_response_text(response_text: str) -> List[str]:
    """Turn the decoded text of a ``.response`` archive member into PDB strings.

    The text is normally treated as concatenated PDB files and split with
    ``_split_pdb_content``. If it instead parses as JSON holding a string or a
    non-empty list of strings, each non-blank string is used as one PDB.
    NOTE(review): the JSON branch is defensive; confirm which payload format
    the service actually writes into the ``.response`` file.
    """
    try:
        decoded: Any = json.loads(response_text)
    except ValueError:  # json.JSONDecodeError subclasses ValueError; PDB text lands here
        return _split_pdb_content(response_text)
    if isinstance(decoded, str):
        decoded = [decoded]
    if isinstance(decoded, list) and decoded and all(isinstance(item, str) for item in decoded):
        return [item.strip() for item in decoded if item.strip()]
    # Valid JSON of some other shape: fall back to plain text splitting.
    return _split_pdb_content(response_text)


def _extract_pdb_strings_from_zip(
    zf: zipfile.ZipFile,
    output_dir: Path,
) -> Tuple[List[str], Optional[str]]:
    """Collect PDB strings from an open NVIDIA NIM result archive.

    The first member whose name ends in ``.response`` (case-insensitive) wins.
    Its decoded text is also saved into *output_dir* under its base name.
    Only when no such member exists are the ``.pdb`` members read instead.

    Returns:
        ``(pdb_strings, saved_response_path)``; the path is None when the
        archive has no ``.response`` member.
    """
    member_names = zf.namelist()
    response_file_name = next(
        (name for name in member_names if name.lower().endswith(".response")), None
    )
    if response_file_name is None:
        logger.info("No .response file found. Looking for individual .pdb files in ZIP.")
        pdb_strings: List[str] = []
        for member_name in member_names:
            if member_name.lower().endswith(".pdb"):
                logger.info(f"Found individual PDB file in ZIP: {member_name}")
                pdb_strings.append(zf.read(member_name).decode("utf-8", errors="replace"))
        return pdb_strings, None

    logger.info(f"Found concatenated response file in ZIP: {response_file_name}")
    response_content_str = zf.read(response_file_name).decode("utf-8", errors="replace")
    # Only the base name is used, so archive paths like "../x.response"
    # cannot place the copy outside output_dir.
    extracted_response_file_path = output_dir / Path(response_file_name).name
    with open(extracted_response_file_path, "w", encoding="utf-8", errors="replace") as f_resp:
        f_resp.write(response_content_str)
    logger.info(f"Saved raw content of '{response_file_name}' to {extracted_response_file_path}")
    return _pdb_strings_from_response_text(response_content_str), str(extracted_response_file_path)


def _build_structure_record(
    model_index: int,
    pdb_str: str,
    output_dir: Path,
    file_stem: str,
) -> Dict[str, Any]:
    """Score one PDB string, save it as ``<file_stem>_model_<model_index>.pdb``, and describe it.

    ``saved_pdb_path`` is None when the PDB file could not be written.
    """
    avg_plddt, plddts_per_ca_residue, plddts_by_chain = calculate_plddt_from_pdb_string(pdb_str)
    structure_data: Dict[str, Any] = {
        "model_index": model_index,  # 0-based position in the extracted list
        "pdb_content": pdb_str,
        "average_plddt": avg_plddt,
        "plddt_scores_per_ca_residue": plddts_per_ca_residue,
        "plddt_scores_per_chain": plddts_by_chain,
        "average_plddt_per_chain": {
            chain_id: (sum(scores) / len(scores) if scores else 0.0)
            for chain_id, scores in plddts_by_chain.items()
        },
    }
    pdb_file_path = output_dir / f"{file_stem}_model_{model_index}.pdb"
    try:
        with open(pdb_file_path, "w", encoding="utf-8") as f_pdb:
            f_pdb.write(pdb_str)
    except (OSError, UnicodeError) as e_write:
        logger.error(f"Failed to write PDB file {pdb_file_path}: {e_write}")
        structure_data["saved_pdb_path"] = None
    else:
        structure_data["saved_pdb_path"] = str(pdb_file_path)
        logger.info(f"Saved PDB model {model_index} to {pdb_file_path} with overall avg_pLDDT: {avg_plddt:.2f}")
    return structure_data


async def _process_nvidia_zip_output(
    zip_content: bytes,
    output_prefix: str,
    response_headers: Optional[Dict[str, str]] = None
) -> Optional[Dict[str, Any]]:
    """
    Processes the ZIP file content from NVIDIA NIM.

    The raw archive is saved as ``./<output_prefix>_results/<output_prefix>.zip``.
    PDB strings are taken from a ``.response`` member, or failing that from the
    ``.pdb`` members. Each non-blank PDB is scored with
    ``calculate_plddt_from_pdb_string`` and saved next to the archive.
    The work is synchronous; the coroutine form is kept for its callers.

    Args:
        zip_content: Raw bytes of the archive returned by the service.
        output_prefix: Names the output directory (created if missing), the
            saved archive, and the per-model PDB files.
        response_headers: HTTP headers of the response; accepted but unused.

    Returns:
        A dict with keys ``zip_file_path``, ``structures``, ``iptm_score`` and
        ``ptm_score`` (always None here), and ``extracted_response_file_path``.
        Each ``structures`` entry has ``model_index``, ``pdb_content``,
        ``average_plddt``, ``plddt_scores_per_ca_residue``,
        ``plddt_scores_per_chain``, ``average_plddt_per_chain`` and
        ``saved_pdb_path``. If the archive is invalid or processing fails, the
        error is logged and the (possibly partial) dict is still returned.

    Raises:
        OSError: if the output directory or the archive copy cannot be written;
            those steps run before error handling starts.
    """
    output_dir = Path(f"./{output_prefix}_results")
    output_dir.mkdir(parents=True, exist_ok=True)
    zip_file_path = output_dir / f"{output_prefix}.zip"
    with open(zip_file_path, "wb") as f:
        f.write(zip_content)
    logger.info(f"Downloaded and saved original ZIP file to {zip_file_path}")

    results: Dict[str, Any] = {
        "zip_file_path": str(zip_file_path),
        "structures": [],
        # Confidence metrics beyond pLDDT are not extracted; kept for callers.
        "iptm_score": None,
        "ptm_score": None,
        "extracted_response_file_path": None,
    }
    try:
        with zipfile.ZipFile(io.BytesIO(zip_content)) as zf:
            pdb_strings_to_process, saved_response_path = _extract_pdb_strings_from_zip(zf, output_dir)
        results["extracted_response_file_path"] = saved_response_path
        if not pdb_strings_to_process:
            logger.warning(f"No PDB content found in ZIP archive {zip_file_path} (either as .response or individual .pdb files).")
            return results
        logger.info(f"Found {len(pdb_strings_to_process)} PDB structure(s) to process.")
        file_stem = Path(output_prefix).stem
        for i, pdb_str in enumerate(pdb_strings_to_process):
            if not pdb_str.strip():
                logger.debug(f"Skipping empty PDB string at index {i}.")
                continue
            results["structures"].append(_build_structure_record(i, pdb_str, output_dir, file_stem))
        if results["structures"]:
            logger.info(f"Successfully processed and calculated pLDDTs for {len(results['structures'])} structures.")
    except zipfile.BadZipFile:
        logger.error(f"Failed to process ZIP: {zip_file_path} is not a valid ZIP file.")
    except Exception as e:
        # Best effort: keep whatever structures were processed before the failure.
        logger.error(f"An error occurred during ZIP processing of {zip_file_path}: {e}", exc_info=True)
    return results
async def call_alphafold2_multimer(
    sequences: List[str],
    api_key: str,
    algorithm: str = "jackhmmer",
    e_value: float = 0.0001,
    iterations: int = 1,
    databases: Optional[List[str]] = None,
    relax_prediction: bool = True,
    selected_models: Optional[List[int]] = None,
    url: str = DEFAULT_URL,
    status_url: str = DEFAULT_STATUS_URL,
    polling_interval: int = 30,
    timeout: int = 3600
) -> Optional[Dict[str, Any]]:
    """
    Call the NVIDIA NIM AlphaFold2-Multimer API.

    One POST is sent with the request parameters. A 200 reply is treated as a
    finished job, and its body is handed to ``_process_nvidia_zip_output``. A
    202 reply carries an ``nvcf-reqid`` header; that job is then polled through
    ``_poll_job_status``.

    Args:
        sequences: Sent as the JSON "sequences" field (presumably one entry per
            chain; confirm against the API docs).
        api_key: Sent as a Bearer token.
        algorithm, e_value, iterations, relax_prediction: Forwarded unchanged.
        databases: Forwarded unchanged; None (the default) means
            ["uniref90", "small_bfd", "mgnify"].
        selected_models: Forwarded as "selected_models" only when not None.
        url: Endpoint receiving the POST.
        status_url: Base status URL; polled as f"{status_url}/{req_id}".
        polling_interval: Seconds between status checks.
        timeout: Overall budget in seconds. The POST itself is capped at
            min(timeout, 600) s; polling gets whatever the POST left over.

    Returns:
        The dict produced by ``_process_nvidia_zip_output``. Returns None when
        the POST fails or returns another status, when a 202 has no request ID,
        on timeout or an unexpected exception (all logged), or when polling
        returns None.
    """
    if databases is None:
        # Fresh list per call instead of a shared mutable default argument.
        databases = ["uniref90", "small_bfd", "mgnify"]
    headers = {
        "content-type": "application/json",
        "Authorization": f"Bearer {api_key}",
        # Presumably the NVCF server-side long-poll window; confirm in NVCF docs.
        "NVCF-POLL-SECONDS": "300",
    }
    data: Dict[str, Any] = {
        "sequences": sequences,
        "algorithm": algorithm,
        "e_value": e_value,
        "iterations": iterations,
        "databases": databases,
        "relax_prediction": relax_prediction
    }
    if selected_models is not None:
        data["selected_models"] = selected_models
        logger.info(f"Using selected_models: {selected_models}")

    loop = asyncio.get_running_loop()
    started_at = loop.time()
    content = b""
    sync_headers: Any = None
    req_id: Optional[str] = None
    try:
        post_timeout = aiohttp.ClientTimeout(total=min(timeout, 600))
        async with aiohttp.ClientSession() as session:
            async with session.post(
                url,
                json=data,
                headers=headers,
                timeout=post_timeout
            ) as response:
                if response.status == 200:
                    logger.info("AlphaFold2-Multimer job completed synchronously.")
                    content = await response.read()
                    sync_headers = response.headers
                elif response.status == 202:
                    req_id = response.headers.get("nvcf-reqid")
                    if not req_id:
                        logger.error("No request ID in 202 response headers")
                        return None
                    logger.info(f"AlphaFold2-Multimer job submitted, request ID: {req_id}")
                else:  # Any other POST status is treated as a failure.
                    logger.error(f"Error calling AlphaFold2-Multimer API (POST): {response.status}")
                    text = await response.text()
                    logger.error(f"Response: {text}")
                    return None
        # The POST response and its session are closed here, so no connection
        # stays open during processing or the potentially long polling phase.
        if req_id is None:  # only the 200 branch leaves req_id unset
            return await _process_nvidia_zip_output(
                zip_content=content,
                output_prefix="alphafold2_multimer_sync_output",
                response_headers=sync_headers
            )
        # Polling gets only the part of the overall budget the POST did not use.
        remaining_budget = max(0.0, timeout - (loop.time() - started_at))
        return await _poll_job_status(
            req_id=req_id,
            headers=headers,
            status_url=status_url,
            polling_interval=polling_interval,
            overall_timeout=remaining_budget
        )
    except asyncio.TimeoutError:
        logger.error("Timeout during AlphaFold2-Multimer API (initial POST).")
        return None
    except Exception as e:
        logger.error(f"Exception during AlphaFold2-Multimer API call (initial POST): {e}", exc_info=True)
        return None
async def _poll_job_status(
    req_id: str,
    headers: Dict[str, str],
    status_url: str,
    polling_interval: int = 30,
    overall_timeout: int = 3600,
    max_transient_errors: int = 3,
) -> Optional[Dict[str, Any]]:
    """
    Poll f"{status_url}/{req_id}" until the job finishes, fails, or time runs out.

    Handling of each status-check response:
      * 200: the body is passed to ``_process_nvidia_zip_output`` and that result
        is returned.
      * 202: the job is still running. If the body is a JSON object, its
        ``status`` and ``percentComplete`` fields are logged. Polling continues.
      * 429/502/503/504: treated as transient and retried, up to
        ``max_transient_errors`` consecutive times (reset by any 202).
      * any other status: logged, returns None.
    Client-side timeouts and ``aiohttp.ClientError`` are logged and retried.
    Any other exception is logged and returns None.

    Args:
        req_id: Request ID returned by the initial POST.
        headers: Sent with every status request.
        status_url: Base status URL; ``/<req_id>`` is appended.
        polling_interval: Seconds to wait between checks, capped to the time left.
        overall_timeout: Total seconds allowed for polling. Each check's own
            timeout is min(330, time left).
        max_transient_errors: Consecutive transient HTTP errors tolerated;
            0 gives up on the first non-200/202 status.

    Returns:
        The processed result dict, or None on failure or timeout.
    """
    loop = asyncio.get_running_loop()
    start_time = loop.time()
    # A little above the 300 s NVCF-POLL-SECONDS value that call_alphafold2_multimer
    # sends, so a server-side long poll can complete before the client gives up.
    status_check_timeout = 330  # seconds
    retryable_statuses = frozenset({429, 502, 503, 504})
    job_status_url = f"{status_url}/{req_id}"
    consecutive_transient_errors = 0
    logger.info(f"Polling job {req_id}. Status check timeout: {status_check_timeout}s, Polling interval: {polling_interval}s, Overall timeout: {overall_timeout}s")

    # One session (and connection pool) for the whole loop instead of one per check.
    async with aiohttp.ClientSession() as session:
        while True:
            elapsed_time = loop.time() - start_time
            if elapsed_time > overall_timeout:
                logger.error(f"Overall polling timeout of {overall_timeout}s exceeded for job {req_id}.")
                return None
            remaining_time_for_overall_timeout = overall_timeout - elapsed_time
            current_status_check_timeout = min(status_check_timeout, remaining_time_for_overall_timeout)
            if current_status_check_timeout <= 0:
                logger.error(f"Not enough time left for another status check for job {req_id} within overall_timeout.")
                return None
            try:
                logger.debug(f"Checking status for {req_id} with timeout {current_status_check_timeout}s.")
                async with session.get(
                    job_status_url,
                    headers=headers,
                    timeout=aiohttp.ClientTimeout(total=current_status_check_timeout)
                ) as response:
                    if response.status == 200:
                        logger.info(f"AlphaFold2-Multimer job {req_id} completed (status 200).")
                        logger.info(f"FINAL 200 OK Response Headers for job {req_id}: {response.headers}")
                        logger.info(f"FINAL 200 OK Content-Type for job {req_id}: {response.content_type}")
                        zip_content_bytes = await response.read()
                        return await _process_nvidia_zip_output(
                            zip_content=zip_content_bytes,
                            output_prefix=f"alphafold2_multimer_output_{req_id}",
                            response_headers=response.headers
                        )
                    if response.status == 202:
                        consecutive_transient_errors = 0
                        try:
                            job_status_json: Any = await response.json()
                        except (aiohttp.ContentTypeError, json.JSONDecodeError):
                            job_status_json = None
                        # An empty body or a non-object JSON value must not abort polling.
                        if isinstance(job_status_json, dict):
                            percent_complete = job_status_json.get('percentComplete', 'N/A')
                            status_message = job_status_json.get('status', 'running')
                            logger.debug(
                                f"Job {req_id} status: {status_message} ({percent_complete}% complete). Polling again in {polling_interval}s."
                            )
                        else:
                            logger.debug(
                                f"Job {req_id} still running (202 status, non-JSON/malformed JSON body). Polling again in {polling_interval}s."
                            )
                    else:
                        logger.error(f"Error checking AlphaFold2-Multimer job status {req_id}: {response.status}")
                        text = await response.text()
                        logger.error(f"Response: {text}")
                        if (response.status not in retryable_statuses
                                or consecutive_transient_errors >= max_transient_errors):
                            return None
                        consecutive_transient_errors += 1
                        logger.warning(
                            f"Retrying status check for job {req_id} after transient HTTP {response.status} "
                            f"({consecutive_transient_errors}/{max_transient_errors})."
                        )
            except asyncio.TimeoutError:
                logger.warning(f"Client-side timeout ({current_status_check_timeout}s) during status check for job {req_id}. Retrying poll after {polling_interval}s sleep.")
            except aiohttp.ClientError as e:
                logger.error(f"Client error polling job status for {req_id}: {e}. Retrying poll after {polling_interval}s.", exc_info=True)
            except Exception as e:
                logger.error(f"Unexpected error polling job status {req_id}: {e}", exc_info=True)
                return None
            # Sleep outside the response context so the connection is released,
            # and never sleep past the overall deadline.
            time_left = overall_timeout - (loop.time() - start_time)
            await asyncio.sleep(min(polling_interval, max(0.0, time_left)))