refactor, full run

This commit is contained in:
based-tachikoma 2025-05-20 20:11:52 -07:00
parent de9dfff221
commit 1ee67de035
12 changed files with 1039 additions and 1127 deletions

View file

@ -1,5 +1,3 @@
"""AlphaFold2 API integration for NVIDIA NIM."""
import os
import logging
import aiohttp
@ -10,7 +8,6 @@ from pathlib import Path
logger = logging.getLogger(__name__)
# Default URL
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
@ -26,12 +23,12 @@ async def call_alphafold2(
url: str = DEFAULT_URL,
status_url: str = DEFAULT_STATUS_URL,
polling_interval: int = 10,
timeout: int = 600, # Increased timeout
max_retries: int = 3 # Added retry parameter
timeout: int = 600,
max_retries: int = 3
) -> Optional[Dict[str, Any]]:
"""
Call the NVIDIA NIM AlphaFold2 API.
Args:
sequence: Protein sequence in one-letter code
api_key: NVIDIA NIM API key
@ -45,18 +42,16 @@ async def call_alphafold2(
status_url: Status URL for checking job completion
polling_interval: Seconds between status checks
timeout: Request timeout in seconds
Returns:
API response or None on failure
"""
# Prepare headers
headers = {
"content-type": "application/json",
"Authorization": f"Bearer {api_key}",
"NVCF-POLL-SECONDS": "300",
}
# Prepare payload
data = {
"sequence": sequence,
"algorithm": algorithm,
@ -66,7 +61,7 @@ async def call_alphafold2(
"relax_prediction": relax_prediction,
"skip_template_search": skip_template_search
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(
@ -75,11 +70,9 @@ async def call_alphafold2(
headers=headers,
timeout=timeout
) as response:
# Check status code
if response.status == 200:
return await response.json()
elif response.status == 202:
# Asynchronous job, get job ID
req_id = response.headers.get("nvcf-reqid")
if req_id:
logger.info(f"AlphaFold2 job submitted, request ID: {req_id}")
@ -103,7 +96,7 @@ async def call_alphafold2(
logger.error(f"Error calling AlphaFold2 API: {e}")
logger.error(traceback.format_exc())
return None
async def _poll_job_status(
req_id: str,
headers: Dict[str, str],
@ -113,14 +106,14 @@ async def _poll_job_status(
) -> Optional[Dict[str, Any]]:
"""
Poll the status endpoint until the job completes.
Args:
req_id: The request ID to check
headers: Request headers
status_url: Status URL for checking job completion
polling_interval: Seconds between status checks
timeout: Request timeout in seconds
Returns:
The final response or None on failure
"""
@ -133,11 +126,9 @@ async def _poll_job_status(
timeout=timeout
) as response:
if response.status == 200:
# Job completed
logger.info(f"AlphaFold2 job {req_id} completed")
return await response.json()
elif response.status == 202:
# Job still running
logger.debug(f"AlphaFold2 job {req_id} still running, polling...")
await asyncio.sleep(polling_interval)
else:
@ -147,4 +138,4 @@ async def _poll_job_status(
return None
except Exception as e:
logger.error(f"Error polling AlphaFold2 job status: {e}")
return None
return None