atropos/environments/hack0/protein_design_env/models/proteinmpnn.py
2025-05-20 20:12:59 -07:00

138 lines
No EOL
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""ProteinMPNN API integration for NVIDIA NIM."""
import os
import logging
import aiohttp
import json
import asyncio
from typing import Dict, List, Any, Optional, Union
from pathlib import Path
# Module logger, named after this module so callers can configure it via the logging hierarchy.
logger: logging.Logger = logging.getLogger(name=__name__)

# Default ProteinMPNN prediction endpoint on NVIDIA's hosted health API.
DEFAULT_URL: str = "https://health.api.nvidia.com/v1/biology/ipd/proteinmpnn/predict"
# Default status endpoint; the job's request ID is appended as a path segment when polling.
DEFAULT_STATUS_URL: str = "https://health.api.nvidia.com/v1/status"
def _nvcf_poll_seconds(timeout: float, requested: int = 300, margin: int = 5) -> int:
    """Return an NVCF-POLL-SECONDS header value that fits inside the client timeout.

    The header tells NVCF how long it may hold the HTTP request open before it
    answers 202 (per NVCF docs). If that window exceeds the client-side timeout,
    aiohttp aborts the request first and the 202/polling path is never reached.
    The window is therefore clamped to ``timeout - margin``, never below 1 second
    and never above ``requested`` (300, the value this module originally sent).
    """
    return max(1, min(requested, int(timeout) - margin))


async def call_proteinmpnn(
    input_pdb: str,
    api_key: str,
    ca_only: bool = False,
    use_soluble_model: bool = False,
    sampling_temp: Optional[List[float]] = None,
    url: str = DEFAULT_URL,
    status_url: str = DEFAULT_STATUS_URL,
    polling_interval: int = 10,
    timeout: int = 60,
) -> Optional[Dict[str, Any]]:
    """
    Call the NVIDIA NIM ProteinMPNN API.

    POSTs the structure to ``url``. An HTTP 200 reply is returned directly. An
    HTTP 202 reply is treated as an asynchronous job: the ``nvcf-reqid`` response
    header is read and the job is polled at ``status_url`` until it finishes.

    Args:
        input_pdb: PDB structure as a string
        api_key: NVIDIA NIM API key
        ca_only: Whether to use only C-alpha atoms
        use_soluble_model: Whether to use the soluble model
        sampling_temp: List of sampling temperatures; ``None`` means ``[0.1]``
        url: API endpoint URL
        status_url: Status URL for checking job completion
        polling_interval: Seconds between status checks
        timeout: Per-request timeout in seconds. It applies to the submit request
            and to each status request, and it also caps the NVCF-POLL-SECONDS
            header.

    Returns:
        The parsed JSON response, or None on failure. Failure means an
        unexpected HTTP status, a missing request ID, a network/parse error, or
        a polling failure. Failures are logged, not raised.
    """
    if sampling_temp is None:
        # Build a fresh list per call instead of sharing a mutable default.
        sampling_temp = [0.1]

    headers = {
        "content-type": "application/json",
        "Authorization": f"Bearer {api_key}",
        # Must be shorter than the client timeout. Otherwise the client gives up
        # before the server can reply 202 and the job is never polled.
        "NVCF-POLL-SECONDS": str(_nvcf_poll_seconds(timeout)),
    }
    data = {
        "input_pdb": input_pdb,
        "ca_only": ca_only,
        "use_soluble_model": use_soluble_model,
        "sampling_temp": sampling_temp,
    }

    req_id: Optional[str] = None
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(
                url,
                json=data,
                headers=headers,
                timeout=aiohttp.ClientTimeout(total=timeout),
            ) as response:
                if response.status == 200:
                    return await response.json()
                if response.status != 202:
                    logger.error("Error calling ProteinMPNN API: %s", response.status)
                    logger.error("Response: %s", await response.text())
                    return None
                # 202: the job was accepted and runs asynchronously.
                req_id = response.headers.get("nvcf-reqid")
    except Exception as e:
        logger.exception("Error calling ProteinMPNN API: %s", e)
        return None

    if not req_id:
        logger.error("No request ID in response headers")
        return None

    logger.info("ProteinMPNN job submitted, request ID: %s", req_id)
    # Poll only after the submit response and its session have been released,
    # so the initial connection is not held open for the whole job.
    return await _poll_job_status(
        req_id=req_id,
        headers=headers,
        status_url=status_url,
        polling_interval=polling_interval,
        timeout=timeout,
    )
async def _poll_job_status(
    req_id: str,
    headers: Dict[str, str],
    status_url: str,
    polling_interval: int = 10,
    timeout: int = 60,
    max_wait: Optional[float] = None,
) -> Optional[Dict[str, Any]]:
    """
    Poll the status endpoint until the job completes.

    Repeatedly sends GET ``{status_url}/{req_id}`` over a single HTTP session:
    - HTTP 200 ends polling and returns the parsed JSON body.
    - HTTP 202 means the job is still running. The response is released and the
      next request follows after ``polling_interval`` seconds.
    - Any other status, or any exception, is logged and ends polling.

    Args:
        req_id: The request ID to check
        headers: Request headers, sent unchanged with every status request
        status_url: Status URL for checking job completion
        polling_interval: Seconds between status checks
        timeout: Per-request timeout in seconds
        max_wait: Optional overall budget in seconds for the whole polling loop.
            ``None`` (the default) keeps polling until the job finishes or fails.

    Returns:
        The final response, or None on failure. Failure includes the
        ``max_wait`` budget running out.
    """
    job_url = f"{status_url}/{req_id}"
    loop = asyncio.get_running_loop()
    # Monotonic deadline (event-loop clock), only when an overall budget is requested.
    deadline = None if max_wait is None else loop.time() + max_wait

    try:
        request_timeout = aiohttp.ClientTimeout(total=timeout)
        # One session for the whole loop: a session owns a connection pool and
        # is meant to be reused rather than recreated for every request.
        async with aiohttp.ClientSession() as session:
            while True:
                async with session.get(
                    job_url,
                    headers=headers,
                    timeout=request_timeout,
                ) as response:
                    if response.status == 200:
                        logger.info("ProteinMPNN job %s completed", req_id)
                        return await response.json()
                    if response.status != 202:
                        logger.error(
                            "Error checking ProteinMPNN job status: %s", response.status
                        )
                        logger.error("Response: %s", await response.text())
                        return None

                # 202: still running. The response is released before sleeping.
                logger.debug("ProteinMPNN job %s still running, polling...", req_id)
                if deadline is not None and loop.time() >= deadline:
                    logger.error(
                        "ProteinMPNN job %s did not finish within %s seconds",
                        req_id,
                        max_wait,
                    )
                    return None
                await asyncio.sleep(polling_interval)
    except Exception as e:
        logger.exception("Error polling ProteinMPNN job status: %s", e)
        return None