rfdiffusion fix

This commit is contained in:
based-tachikoma 2025-05-19 19:42:48 -07:00
parent 4d9bec44c6
commit de9dfff221
8 changed files with 1253 additions and 104 deletions

View file

@ -0,0 +1 @@
"""Protein design model API modules."""

View file

@ -0,0 +1,150 @@
"""AlphaFold2 API integration for NVIDIA NIM."""
import os
import logging
import aiohttp
import json
import asyncio
from typing import Dict, List, Any, Optional
from pathlib import Path
logger = logging.getLogger(__name__)
# Default URL
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
async def call_alphafold2(
sequence: str,
api_key: str,
algorithm: str = "mmseqs2",
e_value: float = 0.0001,
iterations: int = 1,
databases: List[str] = ["small_bfd"],
relax_prediction: bool = False,
skip_template_search: bool = True,
url: str = DEFAULT_URL,
status_url: str = DEFAULT_STATUS_URL,
polling_interval: int = 10,
timeout: int = 600, # Increased timeout
max_retries: int = 3 # Added retry parameter
) -> Optional[Dict[str, Any]]:
"""
Call the NVIDIA NIM AlphaFold2 API.
Args:
sequence: Protein sequence in one-letter code
api_key: NVIDIA NIM API key
algorithm: MSA search algorithm, "mmseqs2" or "jackhmmer"
e_value: E-value threshold for template search
iterations: Number of iterations for template search
databases: List of databases to search
relax_prediction: Whether to relax the prediction
skip_template_search: Whether to skip template search
url: API endpoint URL
status_url: Status URL for checking job completion
polling_interval: Seconds between status checks
timeout: Request timeout in seconds
Returns:
API response or None on failure
"""
# Prepare headers
headers = {
"content-type": "application/json",
"Authorization": f"Bearer {api_key}",
"NVCF-POLL-SECONDS": "300",
}
# Prepare payload
data = {
"sequence": sequence,
"algorithm": algorithm,
"e_value": e_value,
"iterations": iterations,
"databases": databases,
"relax_prediction": relax_prediction,
"skip_template_search": skip_template_search
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(
url,
json=data,
headers=headers,
timeout=timeout
) as response:
# Check status code
if response.status == 200:
return await response.json()
elif response.status == 202:
# Asynchronous job, get job ID
req_id = response.headers.get("nvcf-reqid")
if req_id:
logger.info(f"AlphaFold2 job submitted, request ID: {req_id}")
return await _poll_job_status(
req_id=req_id,
headers=headers,
status_url=status_url,
polling_interval=polling_interval,
timeout=timeout
)
else:
logger.error("No request ID in response headers")
return None
else:
logger.error(f"Error calling AlphaFold2 API: {response.status}")
text = await response.text()
logger.error(f"Response: {text}")
return None
except Exception as e:
import traceback
logger.error(f"Error calling AlphaFold2 API: {e}")
logger.error(traceback.format_exc())
return None
async def _poll_job_status(
req_id: str,
headers: Dict[str, str],
status_url: str,
polling_interval: int = 10,
timeout: int = 60
) -> Optional[Dict[str, Any]]:
"""
Poll the status endpoint until the job completes.
Args:
req_id: The request ID to check
headers: Request headers
status_url: Status URL for checking job completion
polling_interval: Seconds between status checks
timeout: Request timeout in seconds
Returns:
The final response or None on failure
"""
while True:
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{status_url}/{req_id}",
headers=headers,
timeout=timeout
) as response:
if response.status == 200:
# Job completed
logger.info(f"AlphaFold2 job {req_id} completed")
return await response.json()
elif response.status == 202:
# Job still running
logger.debug(f"AlphaFold2 job {req_id} still running, polling...")
await asyncio.sleep(polling_interval)
else:
logger.error(f"Error checking AlphaFold2 job status: {response.status}")
text = await response.text()
logger.error(f"Response: {text}")
return None
except Exception as e:
logger.error(f"Error polling AlphaFold2 job status: {e}")
return None

View file

@ -0,0 +1,366 @@
"""AlphaFold2-Multimer API integration for NVIDIA NIM."""
import os
import logging
import aiohttp
import json
import asyncio
from typing import Dict, List, Any, Optional, Tuple
from pathlib import Path
import zipfile
import io
logger = logging.getLogger(__name__)
# Default URL
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2-multimer"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
# Helper functions
def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
"""
Splits a string containing concatenated PDB file contents.
Assumes models are separated by "ENDMDL" or just "END" for the last/single model.
"""
pdbs = []
current_pdb_lines = []
if not concatenated_pdb_str:
return []
for line in concatenated_pdb_str.splitlines(keepends=True):
current_pdb_lines.append(line)
if line.startswith("ENDMDL") or line.startswith("END "):
pdbs.append("".join(current_pdb_lines).strip())
current_pdb_lines = []
if current_pdb_lines:
remaining_pdb = "".join(current_pdb_lines).strip()
if remaining_pdb:
pdbs.append(remaining_pdb)
return [pdb for pdb in pdbs if pdb]
def calculate_plddt_from_pdb_string(pdb_string: str) -> Tuple[float, List[float], Dict[str, List[float]]]:
"""
Calculates the average pLDDT score from a PDB string for C-alpha atoms.
Also returns a list of all C-alpha pLDDTs and a dictionary of pLDDTs per chain.
Returns:
A tuple containing:
- average_plddt (float): Average pLDDT over all C-alpha atoms.
- plddt_scores_per_ca (List[float]): List of pLDDTs for each C-alpha atom.
- plddt_scores_per_chain (Dict[str, List[float]]): Dict mapping chain ID to its C-alpha pLDDTs.
"""
total_plddt = 0.0
ca_atom_count = 0
plddt_scores_per_ca: List[float] = []
plddt_scores_per_chain: Dict[str, List[float]] = {}
for line in pdb_string.splitlines():
if line.startswith("ATOM"):
atom_name = line[12:16].strip()
if atom_name == "CA":
try:
plddt_value = float(line[60:66].strip())
total_plddt += plddt_value
plddt_scores_per_ca.append(plddt_value)
ca_atom_count += 1
chain_id = line[21:22].strip()
if chain_id not in plddt_scores_per_chain:
plddt_scores_per_chain[chain_id] = []
plddt_scores_per_chain[chain_id].append(plddt_value)
except ValueError:
pass
except IndexError:
pass
if ca_atom_count == 0:
return 0.0, [], {}
average_plddt = total_plddt / ca_atom_count
return average_plddt, plddt_scores_per_ca, plddt_scores_per_chain
async def _process_nvidia_zip_output(
zip_content: bytes,
output_prefix: str,
response_headers: Optional[Dict[str, str]] = None
) -> Optional[Dict[str, Any]]:
"""
Processes the ZIP file content from NVIDIA NIM.
- Expects a .response file with concatenated PDBs, or individual PDB files.
- Extracts PDBs and calculates pLDDT scores for each structure.
- Returns paths to saved files and calculated pLDDT scores.
"""
output_dir = Path(f"./{output_prefix}_results")
output_dir.mkdir(parents=True, exist_ok=True)
# Save the original ZIP file
zip_file_path = output_dir / f"{output_prefix}.zip"
with open(zip_file_path, 'wb') as f:
f.write(zip_content)
logger.info(f"Downloaded and saved original ZIP file to {zip_file_path}")
# Initialize the results dictionary that call_alphafold2_multimer will return
results: Dict[str, Any] = {
"zip_file_path": str(zip_file_path), # Path to the saved original ZIP
"structures": [], # List to hold info for each PDB structure found
# We will NOT be trying to parse iptm/ptm from ranking_debug.json
"iptm_score": None, # Explicitly None, or remove if not needed at all
"ptm_score": None, # Explicitly None, or remove
# Optional: Store path to the .response file if it exists and is processed
"extracted_response_file_path": None,
}
pdb_strings_to_process = []
try:
with zipfile.ZipFile(io.BytesIO(zip_content)) as zf:
# First, check for a ".response" file with concatenated PDBs
response_file_name = None
for member_name in zf.namelist():
if member_name.lower().endswith(".response"):
response_file_name = member_name
break
if response_file_name:
logger.info(f"Found concatenated response file in ZIP: {response_file_name}")
response_data = zf.read(response_file_name)
response_content_str = response_data.decode('utf-8', errors='replace')
# Save the raw .response file
extracted_response_file_path = output_dir / Path(response_file_name).name
with open(extracted_response_file_path, 'w', encoding='utf-8', errors='replace') as f_resp:
f_resp.write(response_content_str)
results["extracted_response_file_path"] = str(extracted_response_file_path)
logger.info(f"Saved raw content of '{response_file_name}' to {extracted_response_file_path}")
pdb_strings_to_process.extend(_split_pdb_content(response_content_str))
else:
# If no ".response" file, look for individual PDB files
logger.info("No .response file found. Looking for individual .pdb files in ZIP.")
for member_name in zf.namelist():
if member_name.lower().endswith(".pdb"):
logger.info(f"Found individual PDB file in ZIP: {member_name}")
pdb_content_bytes = zf.read(member_name)
pdb_strings_to_process.append(pdb_content_bytes.decode('utf-8', errors='replace'))
if not pdb_strings_to_process:
logger.warning(f"No PDB content found in ZIP archive {zip_file_path} (either as .response or individual .pdb files).")
return results # Return with empty structures list
logger.info(f"Found {len(pdb_strings_to_process)} PDB structure(s) to process.")
for i, pdb_str in enumerate(pdb_strings_to_process):
if not pdb_str.strip(): # Skip empty PDB strings
logger.debug(f"Skipping empty PDB string at index {i}.")
continue
structure_data: Dict[str, Any] = {
"model_index": i, # 0-indexed based on order found
"pdb_content": pdb_str # Store the raw PDB string
}
# Calculate pLDDT scores using your existing function
avg_plddt, plddts_per_ca_residue, plddts_by_chain = calculate_plddt_from_pdb_string(pdb_str)
structure_data["average_plddt"] = avg_plddt
structure_data["plddt_scores_per_ca_residue"] = plddts_per_ca_residue # List of pLDDTs for CAs
structure_data["plddt_scores_per_chain"] = plddts_by_chain # Dict: chain_id -> List[pLDDT]
# Calculate average pLDDT for each chain (already in your previous code, good to keep)
avg_plddt_per_chain = {}
for chain_id, chain_plddts in plddts_by_chain.items():
if chain_plddts: # Avoid division by zero
avg_plddt_per_chain[chain_id] = sum(chain_plddts) / len(chain_plddts)
else:
avg_plddt_per_chain[chain_id] = 0.0
structure_data["average_plddt_per_chain"] = avg_plddt_per_chain
# Save the individual PDB string to a file
pdb_file_name_stem = Path(output_prefix).stem
# Suffix for rank if multiple models found, otherwise simpler name
rank_suffix = f"_model_{i}" # Consistent naming for multiple models
pdb_file_path = output_dir / f"{pdb_file_name_stem}{rank_suffix}.pdb"
try:
with open(pdb_file_path, "w", encoding='utf-8') as f_pdb:
f_pdb.write(pdb_str)
structure_data["saved_pdb_path"] = str(pdb_file_path)
logger.info(f"Saved PDB model {i} to {pdb_file_path} with overall avg_pLDDT: {avg_plddt:.2f}")
except Exception as e_write:
logger.error(f"Failed to write PDB file {pdb_file_path}: {e_write}")
structure_data["saved_pdb_path"] = None
results["structures"].append(structure_data)
if results["structures"]:
logger.info(f"Successfully processed and calculated pLDDTs for {len(results['structures'])} structures.")
except zipfile.BadZipFile:
logger.error(f"Failed to process ZIP: {zip_file_path} is not a valid ZIP file.")
# results dictionary is already initialized, will be returned as is, potentially empty.
except Exception as e:
logger.error(f"An error occurred during ZIP processing of {zip_file_path}: {e}", exc_info=True)
# As above, return results which might be partially filled or empty.
return results
async def call_alphafold2_multimer(
sequences: List[str],
api_key: str,
algorithm: str = "jackhmmer",
e_value: float = 0.0001,
iterations: int = 1,
databases: List[str] = ["uniref90", "small_bfd", "mgnify"],
relax_prediction: bool = True,
selected_models: Optional[List[int]] = None,
url: str = DEFAULT_URL,
status_url: str = DEFAULT_STATUS_URL,
polling_interval: int = 30,
timeout: int = 3600
) -> Optional[Dict[str, Any]]:
"""
Call the NVIDIA NIM AlphaFold2-Multimer API.
Returns a dictionary structured by _process_nvidia_zip_output.
"""
headers = {
"content-type": "application/json",
"Authorization": f"Bearer {api_key}",
"NVCF-POLL-SECONDS": "300",
}
data: Dict[str, Any] = {
"sequences": sequences,
"algorithm": algorithm,
"e_value": e_value,
"iterations": iterations,
"databases": databases,
"relax_prediction": relax_prediction
}
if selected_models is not None:
data["selected_models"] = selected_models
logger.info(f"Using selected_models: {selected_models}")
try:
initial_post_timeout = min(timeout, 600)
async with aiohttp.ClientSession() as session:
async with session.post(
url,
json=data,
headers=headers,
timeout=initial_post_timeout
) as response:
if response.status == 200:
logger.info("AlphaFold2-Multimer job completed synchronously.")
content = await response.read()
return await _process_nvidia_zip_output(
zip_content=content,
output_prefix="alphafold2_multimer_sync_output",
response_headers=response.headers
)
elif response.status == 202:
req_id = response.headers.get("nvcf-reqid")
if req_id:
logger.info(f"AlphaFold2-Multimer job submitted, request ID: {req_id}")
return await _poll_job_status(
req_id=req_id,
headers=headers,
status_url=status_url,
polling_interval=polling_interval,
overall_timeout=timeout
)
else:
logger.error("No request ID in 202 response headers")
return None
else: # Handle other error statuses from POST
logger.error(f"Error calling AlphaFold2-Multimer API (POST): {response.status}")
text = await response.text()
logger.error(f"Response: {text}")
return None
except asyncio.TimeoutError:
logger.error(f"Timeout during AlphaFold2-Multimer API (initial POST).")
return None
except Exception as e:
logger.error(f"Exception during AlphaFold2-Multimer API call (initial POST): {e}", exc_info=True)
return None
async def _poll_job_status(
req_id: str,
headers: Dict[str, str],
status_url: str,
polling_interval: int = 30,
overall_timeout: int = 3600
) -> Optional[Dict[str, Any]]:
start_time = asyncio.get_event_loop().time()
# Allow status checks to wait longer, e.g., slightly more than NVCF-POLL-SECONDS if you use it,
# or a fixed reasonably long duration.
# The NVCF-POLL-SECONDS in the POST header is 300s.
# The GET request to /status should also ideally respect a similar long-poll duration from the server.
status_check_timeout = 330 # seconds (e.g., 5.5 minutes)
logger.info(f"Polling job {req_id}. Status check timeout: {status_check_timeout}s, Polling interval: {polling_interval}s, Overall timeout: {overall_timeout}s")
while True:
current_loop_time = asyncio.get_event_loop().time()
elapsed_time = current_loop_time - start_time
if elapsed_time > overall_timeout:
logger.error(f"Overall polling timeout of {overall_timeout}s exceeded for job {req_id}.")
return None
remaining_time_for_overall_timeout = overall_timeout - elapsed_time
current_status_check_timeout = min(status_check_timeout, remaining_time_for_overall_timeout)
if current_status_check_timeout <= 0:
logger.error(f"Not enough time left for another status check for job {req_id} within overall_timeout.")
return None
try:
async with aiohttp.ClientSession() as session:
logger.debug(f"Checking status for {req_id} with timeout {current_status_check_timeout}s.")
async with session.get(
f"{status_url}/{req_id}",
headers=headers,
timeout=current_status_check_timeout
) as response:
if response.status == 200:
logger.info(f"AlphaFold2-Multimer job {req_id} completed (status 200).")
logger.info(f"FINAL 200 OK Response Headers for job {req_id}: {response.headers}")
logger.info(f"FINAL 200 OK Content-Type for job {req_id}: {response.content_type}")
zip_content_bytes = await response.read()
return await _process_nvidia_zip_output(
zip_content=zip_content_bytes,
output_prefix=f"alphafold2_multimer_output_{req_id}",
response_headers=response.headers
)
elif response.status == 202:
try:
job_status_json = await response.json()
percent_complete = job_status_json.get('percentComplete', 'N/A')
status_message = job_status_json.get('status', 'running')
logger.debug(
f"Job {req_id} status: {status_message} ({percent_complete}% complete). Polling again in {polling_interval}s."
)
except (aiohttp.ContentTypeError, json.JSONDecodeError):
logger.debug(
f"Job {req_id} still running (202 status, non-JSON/malformed JSON body). Polling again in {polling_interval}s."
)
await asyncio.sleep(polling_interval)
else: # Handle other error statuses from status GET
logger.error(f"Error checking AlphaFold2-Multimer job status {req_id}: {response.status}")
text = await response.text()
logger.error(f"Response: {text}")
# Log the error, but continue polling unless it's a fatal client error (4xx other than 429)
# or if the server explicitly indicates failure (e.g. 500, or a 200 with error status in body)
# For a 504 like you saw, we might want to retry a few times then give up.
# For now, this will return None on non-200/202, which your test script will catch.
return None
except asyncio.TimeoutError:
logger.warning(f"Client-side timeout ({current_status_check_timeout}s) during status check for job {req_id}. Retrying poll after {polling_interval}s sleep.")
await asyncio.sleep(polling_interval)
except aiohttp.ClientError as e:
logger.error(f"Client error polling job status for {req_id}: {e}. Retrying poll after {polling_interval}s.", exc_info=True)
await asyncio.sleep(polling_interval)
except Exception as e:
logger.error(f"Unexpected error polling job status {req_id}: {e}", exc_info=True)
return None

View file

@ -0,0 +1,138 @@
"""ProteinMPNN API integration for NVIDIA NIM."""
import os
import logging
import aiohttp
import json
import asyncio
from typing import Dict, List, Any, Optional, Union
from pathlib import Path
logger = logging.getLogger(__name__)
# Default URL
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/ipd/proteinmpnn/predict"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
async def call_proteinmpnn(
input_pdb: str,
api_key: str,
ca_only: bool = False,
use_soluble_model: bool = False,
sampling_temp: List[float] = [0.1],
url: str = DEFAULT_URL,
status_url: str = DEFAULT_STATUS_URL,
polling_interval: int = 10,
timeout: int = 60
) -> Optional[Dict[str, Any]]:
"""
Call the NVIDIA NIM ProteinMPNN API.
Args:
input_pdb: PDB structure as a string
api_key: NVIDIA NIM API key
ca_only: Whether to use only Cα atoms
use_soluble_model: Whether to use the soluble model
sampling_temp: List of sampling temperatures
url: API endpoint URL
status_url: Status URL for checking job completion
polling_interval: Seconds between status checks
timeout: Request timeout in seconds
Returns:
API response or None on failure
"""
# Prepare headers
headers = {
"content-type": "application/json",
"Authorization": f"Bearer {api_key}",
"NVCF-POLL-SECONDS": "300",
}
# Prepare payload
data = {
"input_pdb": input_pdb,
"ca_only": ca_only,
"use_soluble_model": use_soluble_model,
"sampling_temp": sampling_temp
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(
url,
json=data,
headers=headers,
timeout=timeout
) as response:
# Check status code
if response.status == 200:
return await response.json()
elif response.status == 202:
# Asynchronous job, get job ID
req_id = response.headers.get("nvcf-reqid")
if req_id:
logger.info(f"ProteinMPNN job submitted, request ID: {req_id}")
return await _poll_job_status(
req_id=req_id,
headers=headers,
status_url=status_url,
polling_interval=polling_interval,
timeout=timeout
)
else:
logger.error("No request ID in response headers")
return None
else:
logger.error(f"Error calling ProteinMPNN API: {response.status}")
text = await response.text()
logger.error(f"Response: {text}")
return None
except Exception as e:
logger.error(f"Error calling ProteinMPNN API: {e}")
return None
async def _poll_job_status(
req_id: str,
headers: Dict[str, str],
status_url: str,
polling_interval: int = 10,
timeout: int = 60
) -> Optional[Dict[str, Any]]:
"""
Poll the status endpoint until the job completes.
Args:
req_id: The request ID to check
headers: Request headers
status_url: Status URL for checking job completion
polling_interval: Seconds between status checks
timeout: Request timeout in seconds
Returns:
The final response or None on failure
"""
while True:
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{status_url}/{req_id}",
headers=headers,
timeout=timeout
) as response:
if response.status == 200:
# Job completed
logger.info(f"ProteinMPNN job {req_id} completed")
return await response.json()
elif response.status == 202:
# Job still running
logger.debug(f"ProteinMPNN job {req_id} still running, polling...")
await asyncio.sleep(polling_interval)
else:
logger.error(f"Error checking ProteinMPNN job status: {response.status}")
text = await response.text()
logger.error(f"Response: {text}")
return None
except Exception as e:
logger.error(f"Error polling ProteinMPNN job status: {e}")
return None

View file

@ -0,0 +1,142 @@
"""RFDiffusion API integration for NVIDIA NIM."""
import os
import logging
import aiohttp
import json
import asyncio
from typing import Dict, List, Any, Optional, Union
from pathlib import Path
logger = logging.getLogger(__name__)
# Default URL
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/ipd/rfdiffusion/generate"
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
async def call_rfdiffusion(
input_pdb: str,
api_key: str,
contigs: str = None,
hotspot_res: List[str] = None,
diffusion_steps: int = 15,
url: str = DEFAULT_URL,
status_url: str = DEFAULT_STATUS_URL,
polling_interval: int = 10,
timeout: int = 60
) -> Optional[Dict[str, Any]]:
"""
Call the NVIDIA NIM RFDiffusion API.
Args:
input_pdb: PDB structure as a string
api_key: NVIDIA NIM API key
contigs: Contig string (e.g. "A20-60/0 50-100")
hotspot_res: List of hotspot residues (e.g. ["A50","A51"])
diffusion_steps: Number of diffusion steps
url: API endpoint URL
status_url: Status URL for checking job completion
polling_interval: Seconds between status checks
timeout: Request timeout in seconds
Returns:
API response or None on failure
"""
# Prepare headers
headers = {
"content-type": "application/json",
"Authorization": f"Bearer {api_key}",
"NVCF-POLL-SECONDS": "300",
}
# Prepare payload
data = {
"input_pdb": input_pdb,
"diffusion_steps": diffusion_steps
}
# Add optional parameters if provided
if contigs:
data["contigs"] = contigs
if hotspot_res:
data["hotspot_res"] = hotspot_res
try:
async with aiohttp.ClientSession() as session:
async with session.post(
url,
json=data,
headers=headers,
timeout=timeout
) as response:
# Check status code
if response.status == 200:
return await response.json()
elif response.status == 202:
# Asynchronous job, get job ID
req_id = response.headers.get("nvcf-reqid")
if req_id:
logger.info(f"RFDiffusion job submitted, request ID: {req_id}")
return await _poll_job_status(
req_id=req_id,
headers=headers,
status_url=status_url,
polling_interval=polling_interval,
timeout=timeout
)
else:
logger.error("No request ID in response headers")
return None
else:
logger.error(f"Error calling RFDiffusion API: {response.status}")
text = await response.text()
logger.error(f"Response: {text}")
return None
except Exception as e:
logger.error(f"Error calling RFDiffusion API: {e}")
return None
async def _poll_job_status(
req_id: str,
headers: Dict[str, str],
status_url: str,
polling_interval: int = 10,
timeout: int = 60
) -> Optional[Dict[str, Any]]:
"""
Poll the status endpoint until the job completes.
Args:
req_id: The request ID to check
headers: Request headers
status_url: Status URL for checking job completion
polling_interval: Seconds between status checks
timeout: Request timeout in seconds
Returns:
The final response or None on failure
"""
while True:
try:
async with aiohttp.ClientSession() as session:
async with session.get(
f"{status_url}/{req_id}",
headers=headers,
timeout=timeout
) as response:
if response.status == 200:
# Job completed
logger.info(f"RFDiffusion job {req_id} completed")
return await response.json()
elif response.status == 202:
# Job still running
logger.debug(f"RFDiffusion job {req_id} still running, polling...")
await asyncio.sleep(polling_interval)
else:
logger.error(f"Error checking RFDiffusion job status: {response.status}")
text = await response.text()
logger.error(f"Response: {text}")
return None
except Exception as e:
logger.error(f"Error polling RFDiffusion job status: {e}")
return None