refactor, full run

This commit is contained in:
based-tachikoma 2025-05-20 20:11:52 -07:00
parent de9dfff221
commit 1ee67de035
12 changed files with 1039 additions and 1127 deletions

View file

@ -1,5 +1,3 @@
"""RFDiffusion API integration for NVIDIA NIM."""
import os
import logging
import aiohttp
@ -10,7 +8,6 @@ from pathlib import Path
logger = logging.getLogger(__name__)
# Default URL
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/ipd/rfdiffusion/generate"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
@ -27,7 +24,7 @@ async def call_rfdiffusion(
) -> Optional[Dict[str, Any]]:
"""
Call the NVIDIA NIM RFDiffusion API.
Args:
input_pdb: PDB structure as a string
api_key: NVIDIA NIM API key
@ -38,29 +35,26 @@ async def call_rfdiffusion(
status_url: Status URL for checking job completion
polling_interval: Seconds between status checks
timeout: Request timeout in seconds
Returns:
API response or None on failure
"""
# Prepare headers
headers = {
"content-type": "application/json",
"Authorization": f"Bearer {api_key}",
"NVCF-POLL-SECONDS": "300",
}
# Prepare payload
data = {
"input_pdb": input_pdb,
"diffusion_steps": diffusion_steps
}
# Add optional parameters if provided
if contigs:
data["contigs"] = contigs
if hotspot_res:
data["hotspot_res"] = hotspot_res
try:
async with aiohttp.ClientSession() as session:
async with session.post(
@ -69,11 +63,9 @@ async def call_rfdiffusion(
headers=headers,
timeout=timeout
) as response:
# Check status code
if response.status == 200:
return await response.json()
elif response.status == 202:
# Asynchronous job, get job ID
req_id = response.headers.get("nvcf-reqid")
if req_id:
logger.info(f"RFDiffusion job submitted, request ID: {req_id}")
@ -95,7 +87,7 @@ async def call_rfdiffusion(
except Exception as e:
logger.error(f"Error calling RFDiffusion API: {e}")
return None
async def _poll_job_status(
req_id: str,
headers: Dict[str, str],
@ -105,14 +97,14 @@ async def _poll_job_status(
) -> Optional[Dict[str, Any]]:
"""
Poll the status endpoint until the job completes.
Args:
req_id: The request ID to check
headers: Request headers
status_url: Status URL for checking job completion
polling_interval: Seconds between status checks
timeout: Request timeout in seconds
Returns:
The final response or None on failure
"""
@ -125,11 +117,9 @@ async def _poll_job_status(
timeout=timeout
) as response:
if response.status == 200:
# Job completed
logger.info(f"RFDiffusion job {req_id} completed")
return await response.json()
elif response.status == 202:
# Job still running
logger.debug(f"RFDiffusion job {req_id} still running, polling...")
await asyncio.sleep(polling_interval)
else:
@ -139,4 +129,4 @@ async def _poll_job_status(
return None
except Exception as e:
logger.error(f"Error polling RFDiffusion job status: {e}")
return None
return None