diff --git a/environments/hack0/protein_design_env/models/__init__.py b/environments/hack0/protein_design_env/models/__init__.py
new file mode 100644
index 00000000..2cd7109b
--- /dev/null
+++ b/environments/hack0/protein_design_env/models/__init__.py
@@ -0,0 +1 @@
+"""Protein design model API modules."""
\ No newline at end of file
diff --git a/environments/hack0/protein_design_env/models/alphafold2.py b/environments/hack0/protein_design_env/models/alphafold2.py
new file mode 100644
index 00000000..038ec06c
--- /dev/null
+++ b/environments/hack0/protein_design_env/models/alphafold2.py
@@ -0,0 +1,150 @@
+"""AlphaFold2 API integration for NVIDIA NIM."""
+
+import os
+import logging
+import aiohttp
+import json
+import asyncio
+from typing import Dict, List, Any, Optional
+from pathlib import Path
+
+logger = logging.getLogger(__name__)
+
+# Default URL
+DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2"
+DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
+
+async def call_alphafold2(
+    sequence: str,
+    api_key: str,
+    algorithm: str = "mmseqs2",
+    e_value: float = 0.0001,
+    iterations: int = 1,
+    databases: List[str] = ["small_bfd"],
+    relax_prediction: bool = False,
+    skip_template_search: bool = True,
+    url: str = DEFAULT_URL,
+    status_url: str = DEFAULT_STATUS_URL,
+    polling_interval: int = 10,
+    timeout: int = 600,  # Increased timeout
+    max_retries: int = 3  # Added retry parameter
+) -> Optional[Dict[str, Any]]:
+    """
+    Call the NVIDIA NIM AlphaFold2 API.
+
+    Args:
+        sequence: Protein sequence in one-letter code
+        api_key: NVIDIA NIM API key
+        algorithm: MSA search algorithm, "mmseqs2" or "jackhmmer"
+        e_value: E-value threshold for template search
+        iterations: Number of iterations for template search
+        databases: List of databases to search
+        relax_prediction: Whether to relax the prediction
+        skip_template_search: Whether to skip template search
+        url: API endpoint URL
+        status_url: Status URL for checking job completion
+        polling_interval: Seconds between status checks
+        timeout: Request timeout in seconds
+
+    Returns:
+        API response or None on failure
+    """
+    # Prepare headers
+    headers = {
+        "content-type": "application/json",
+        "Authorization": f"Bearer {api_key}",
+        "NVCF-POLL-SECONDS": "300",
+    }
+
+    # Prepare payload
+    data = {
+        "sequence": sequence,
+        "algorithm": algorithm,
+        "e_value": e_value,
+        "iterations": iterations,
+        "databases": databases,
+        "relax_prediction": relax_prediction,
+        "skip_template_search": skip_template_search
+    }
+
+    try:
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                url,
+                json=data,
+                headers=headers,
+                timeout=timeout
+            ) as response:
+                # Check status code
+                if response.status == 200:
+                    return await response.json()
+                elif response.status == 202:
+                    # Asynchronous job, get job ID
+                    req_id = response.headers.get("nvcf-reqid")
+                    if req_id:
+                        logger.info(f"AlphaFold2 job submitted, request ID: {req_id}")
+                        return await _poll_job_status(
+                            req_id=req_id,
+                            headers=headers,
+                            status_url=status_url,
+                            polling_interval=polling_interval,
+                            timeout=timeout
+                        )
+                    else:
+                        logger.error("No request ID in response headers")
+                        return None
+                else:
+                    logger.error(f"Error calling AlphaFold2 API: {response.status}")
+                    text = await response.text()
+                    logger.error(f"Response: {text}")
+                    return None
+    except Exception as e:
+        import traceback
+        logger.error(f"Error calling AlphaFold2 API: {e}")
+        logger.error(traceback.format_exc())
+        return None
+
+async def _poll_job_status(
+    req_id: str,
+    headers: Dict[str, str],
+    status_url: str,
+    polling_interval: int = 10,
+    timeout: int = 60
+) -> Optional[Dict[str, Any]]:
+    """
+    Poll the status endpoint until the job completes.
+
+    Args:
+        req_id: The request ID to check
+        headers: Request headers
+        status_url: Status URL for checking job completion
+        polling_interval: Seconds between status checks
+        timeout: Request timeout in seconds
+
+    Returns:
+        The final response or None on failure
+    """
+    while True:
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.get(
+                    f"{status_url}/{req_id}",
+                    headers=headers,
+                    timeout=timeout
+                ) as response:
+                    if response.status == 200:
+                        # Job completed
+                        logger.info(f"AlphaFold2 job {req_id} completed")
+                        return await response.json()
+                    elif response.status == 202:
+                        # Job still running
+                        logger.debug(f"AlphaFold2 job {req_id} still running, polling...")
+                        await asyncio.sleep(polling_interval)
+                    else:
+                        logger.error(f"Error checking AlphaFold2 job status: {response.status}")
+                        text = await response.text()
+                        logger.error(f"Response: {text}")
+                        return None
+        except Exception as e:
+            logger.error(f"Error polling AlphaFold2 job status: {e}")
+            return None
\ No newline at end of file
diff --git a/environments/hack0/protein_design_env/models/alphafold2_multimer.py b/environments/hack0/protein_design_env/models/alphafold2_multimer.py
new file mode 100644
index 00000000..16e0a894
--- /dev/null
+++ b/environments/hack0/protein_design_env/models/alphafold2_multimer.py
@@ -0,0 +1,366 @@
+"""AlphaFold2-Multimer API integration for NVIDIA NIM."""
+
+import os
+import logging
+import aiohttp
+import json
+import asyncio
+from typing import Dict, List, Any, Optional, Tuple
+from pathlib import Path
+import zipfile
+import io
+
+logger = logging.getLogger(__name__)
+
+# Default URL
+DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2-multimer"
+DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
+
+# Helper functions
+def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
+    """
+    Splits a string containing concatenated PDB file contents.
+    Assumes models are separated by "ENDMDL" or just "END" for the last/single model.
+    """
+    pdbs = []
+    current_pdb_lines = []
+    if not concatenated_pdb_str:
+        return []
+
+    for line in concatenated_pdb_str.splitlines(keepends=True):
+        current_pdb_lines.append(line)
+        if line.startswith("ENDMDL") or line.startswith("END "):
+            pdbs.append("".join(current_pdb_lines).strip())
+            current_pdb_lines = []
+
+    if current_pdb_lines:
+        remaining_pdb = "".join(current_pdb_lines).strip()
+        if remaining_pdb:
+            pdbs.append(remaining_pdb)
+
+    return [pdb for pdb in pdbs if pdb]
+
+
+def calculate_plddt_from_pdb_string(pdb_string: str) -> Tuple[float, List[float], Dict[str, List[float]]]:
+    """
+    Calculates the average pLDDT score from a PDB string for C-alpha atoms.
+    Also returns a list of all C-alpha pLDDTs and a dictionary of pLDDTs per chain.
+
+    Returns:
+        A tuple containing:
+        - average_plddt (float): Average pLDDT over all C-alpha atoms.
+        - plddt_scores_per_ca (List[float]): List of pLDDTs for each C-alpha atom.
+        - plddt_scores_per_chain (Dict[str, List[float]]): Dict mapping chain ID to its C-alpha pLDDTs.
+    """
+    total_plddt = 0.0
+    ca_atom_count = 0
+    plddt_scores_per_ca: List[float] = []
+    plddt_scores_per_chain: Dict[str, List[float]] = {}
+
+    for line in pdb_string.splitlines():
+        if line.startswith("ATOM"):
+            atom_name = line[12:16].strip()
+            if atom_name == "CA":
+                try:
+                    plddt_value = float(line[60:66].strip())
+                    total_plddt += plddt_value
+                    plddt_scores_per_ca.append(plddt_value)
+                    ca_atom_count += 1
+
+                    chain_id = line[21:22].strip()
+                    if chain_id not in plddt_scores_per_chain:
+                        plddt_scores_per_chain[chain_id] = []
+                    plddt_scores_per_chain[chain_id].append(plddt_value)
+
+                except ValueError:
+                    pass
+                except IndexError:
+                    pass
+
+    if ca_atom_count == 0:
+        return 0.0, [], {}
+
+    average_plddt = total_plddt / ca_atom_count
+    return average_plddt, plddt_scores_per_ca, plddt_scores_per_chain
+
+async def _process_nvidia_zip_output(
+    zip_content: bytes,
+    output_prefix: str,
+    response_headers: Optional[Dict[str, str]] = None
+) -> Optional[Dict[str, Any]]:
+    """
+    Processes the ZIP file content from NVIDIA NIM.
+    - Expects a .response file with concatenated PDBs, or individual PDB files.
+    - Extracts PDBs and calculates pLDDT scores for each structure.
+    - Returns paths to saved files and calculated pLDDT scores.
+    """
+    output_dir = Path(f"./{output_prefix}_results")
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # Save the original ZIP file
+    zip_file_path = output_dir / f"{output_prefix}.zip"
+    with open(zip_file_path, 'wb') as f:
+        f.write(zip_content)
+    logger.info(f"Downloaded and saved original ZIP file to {zip_file_path}")
+
+    # Initialize the results dictionary that call_alphafold2_multimer will return
+    results: Dict[str, Any] = {
+        "zip_file_path": str(zip_file_path),  # Path to the saved original ZIP
+        "structures": [],  # List to hold info for each PDB structure found
+        # We will NOT be trying to parse iptm/ptm from ranking_debug.json
+        "iptm_score": None,  # Explicitly None, or remove if not needed at all
+        "ptm_score": None,  # Explicitly None, or remove
+        # Optional: Store path to the .response file if it exists and is processed
+        "extracted_response_file_path": None,
+    }
+
+    pdb_strings_to_process = []
+
+    try:
+        with zipfile.ZipFile(io.BytesIO(zip_content)) as zf:
+            # First, check for a ".response" file with concatenated PDBs
+            response_file_name = None
+            for member_name in zf.namelist():
+                if member_name.lower().endswith(".response"):
+                    response_file_name = member_name
+                    break
+
+            if response_file_name:
+                logger.info(f"Found concatenated response file in ZIP: {response_file_name}")
+                response_data = zf.read(response_file_name)
+                response_content_str = response_data.decode('utf-8', errors='replace')
+
+                # Save the raw .response file
+                extracted_response_file_path = output_dir / Path(response_file_name).name
+                with open(extracted_response_file_path, 'w', encoding='utf-8', errors='replace') as f_resp:
+                    f_resp.write(response_content_str)
+                results["extracted_response_file_path"] = str(extracted_response_file_path)
+                logger.info(f"Saved raw content of '{response_file_name}' to {extracted_response_file_path}")
+
+                pdb_strings_to_process.extend(_split_pdb_content(response_content_str))
+            else:
+                # If no ".response" file, look for individual PDB files
+                logger.info("No .response file found. Looking for individual .pdb files in ZIP.")
+                for member_name in zf.namelist():
+                    if member_name.lower().endswith(".pdb"):
+                        logger.info(f"Found individual PDB file in ZIP: {member_name}")
+                        pdb_content_bytes = zf.read(member_name)
+                        pdb_strings_to_process.append(pdb_content_bytes.decode('utf-8', errors='replace'))
+
+            if not pdb_strings_to_process:
+                logger.warning(f"No PDB content found in ZIP archive {zip_file_path} (either as .response or individual .pdb files).")
+                return results  # Return with empty structures list
+
+            logger.info(f"Found {len(pdb_strings_to_process)} PDB structure(s) to process.")
+
+            for i, pdb_str in enumerate(pdb_strings_to_process):
+                if not pdb_str.strip():  # Skip empty PDB strings
+                    logger.debug(f"Skipping empty PDB string at index {i}.")
+                    continue
+
+                structure_data: Dict[str, Any] = {
+                    "model_index": i,  # 0-indexed based on order found
+                    "pdb_content": pdb_str  # Store the raw PDB string
+                }
+
+                # Calculate pLDDT scores using your existing function
+                avg_plddt, plddts_per_ca_residue, plddts_by_chain = calculate_plddt_from_pdb_string(pdb_str)
+
+                structure_data["average_plddt"] = avg_plddt
+                structure_data["plddt_scores_per_ca_residue"] = plddts_per_ca_residue  # List of pLDDTs for CAs
+                structure_data["plddt_scores_per_chain"] = plddts_by_chain  # Dict: chain_id -> List[pLDDT]
+
+                # Calculate average pLDDT for each chain (already in your previous code, good to keep)
+                avg_plddt_per_chain = {}
+                for chain_id, chain_plddts in plddts_by_chain.items():
+                    if chain_plddts:  # Avoid division by zero
+                        avg_plddt_per_chain[chain_id] = sum(chain_plddts) / len(chain_plddts)
+                    else:
+                        avg_plddt_per_chain[chain_id] = 0.0
+                structure_data["average_plddt_per_chain"] = avg_plddt_per_chain
+
+                # Save the individual PDB string to a file
+                pdb_file_name_stem = Path(output_prefix).stem
+                # Suffix for rank if multiple models found, otherwise simpler name
+                rank_suffix = f"_model_{i}"  # Consistent naming for multiple models
+                pdb_file_path = output_dir / f"{pdb_file_name_stem}{rank_suffix}.pdb"
+
+                try:
+                    with open(pdb_file_path, "w", encoding='utf-8') as f_pdb:
+                        f_pdb.write(pdb_str)
+                    structure_data["saved_pdb_path"] = str(pdb_file_path)
+                    logger.info(f"Saved PDB model {i} to {pdb_file_path} with overall avg_pLDDT: {avg_plddt:.2f}")
+                except Exception as e_write:
+                    logger.error(f"Failed to write PDB file {pdb_file_path}: {e_write}")
+                    structure_data["saved_pdb_path"] = None
+
+                results["structures"].append(structure_data)
+
+            if results["structures"]:
+                logger.info(f"Successfully processed and calculated pLDDTs for {len(results['structures'])} structures.")
+
+    except zipfile.BadZipFile:
+        logger.error(f"Failed to process ZIP: {zip_file_path} is not a valid ZIP file.")
+        # results dictionary is already initialized, will be returned as is, potentially empty.
+    except Exception as e:
+        logger.error(f"An error occurred during ZIP processing of {zip_file_path}: {e}", exc_info=True)
+        # As above, return results which might be partially filled or empty.
+
+    return results
+
+async def call_alphafold2_multimer(
+    sequences: List[str],
+    api_key: str,
+    algorithm: str = "jackhmmer",
+    e_value: float = 0.0001,
+    iterations: int = 1,
+    databases: List[str] = ["uniref90", "small_bfd", "mgnify"],
+    relax_prediction: bool = True,
+    selected_models: Optional[List[int]] = None,
+    url: str = DEFAULT_URL,
+    status_url: str = DEFAULT_STATUS_URL,
+    polling_interval: int = 30,
+    timeout: int = 3600
+) -> Optional[Dict[str, Any]]:
+    """
+    Call the NVIDIA NIM AlphaFold2-Multimer API.
+ Returns a dictionary structured by _process_nvidia_zip_output. + """ + headers = { + "content-type": "application/json", + "Authorization": f"Bearer {api_key}", + "NVCF-POLL-SECONDS": "300", + } + data: Dict[str, Any] = { + "sequences": sequences, + "algorithm": algorithm, + "e_value": e_value, + "iterations": iterations, + "databases": databases, + "relax_prediction": relax_prediction + } + if selected_models is not None: + data["selected_models"] = selected_models + logger.info(f"Using selected_models: {selected_models}") + + try: + initial_post_timeout = min(timeout, 600) + async with aiohttp.ClientSession() as session: + async with session.post( + url, + json=data, + headers=headers, + timeout=initial_post_timeout + ) as response: + if response.status == 200: + logger.info("AlphaFold2-Multimer job completed synchronously.") + content = await response.read() + return await _process_nvidia_zip_output( + zip_content=content, + output_prefix="alphafold2_multimer_sync_output", + response_headers=response.headers + ) + elif response.status == 202: + req_id = response.headers.get("nvcf-reqid") + if req_id: + logger.info(f"AlphaFold2-Multimer job submitted, request ID: {req_id}") + return await _poll_job_status( + req_id=req_id, + headers=headers, + status_url=status_url, + polling_interval=polling_interval, + overall_timeout=timeout + ) + else: + logger.error("No request ID in 202 response headers") + return None + else: # Handle other error statuses from POST + logger.error(f"Error calling AlphaFold2-Multimer API (POST): {response.status}") + text = await response.text() + logger.error(f"Response: {text}") + return None + except asyncio.TimeoutError: + logger.error(f"Timeout during AlphaFold2-Multimer API (initial POST).") + return None + except Exception as e: + logger.error(f"Exception during AlphaFold2-Multimer API call (initial POST): {e}", exc_info=True) + return None + +async def _poll_job_status( + req_id: str, + headers: Dict[str, str], + status_url: str, + polling_interval: int = 30, + overall_timeout: int = 3600 +) -> Optional[Dict[str, Any]]: + start_time = asyncio.get_event_loop().time() + # Allow status checks to wait longer, e.g., slightly more than NVCF-POLL-SECONDS if you use it, + # or a fixed reasonably long duration. + # The NVCF-POLL-SECONDS in the POST header is 300s. + # The GET request to /status should also ideally respect a similar long-poll duration from the server. + status_check_timeout = 330 # seconds (e.g., 5.5 minutes) + logger.info(f"Polling job {req_id}. 
Status check timeout: {status_check_timeout}s, Polling interval: {polling_interval}s, Overall timeout: {overall_timeout}s") + + while True: + current_loop_time = asyncio.get_event_loop().time() + elapsed_time = current_loop_time - start_time + + if elapsed_time > overall_timeout: + logger.error(f"Overall polling timeout of {overall_timeout}s exceeded for job {req_id}.") + return None + + remaining_time_for_overall_timeout = overall_timeout - elapsed_time + current_status_check_timeout = min(status_check_timeout, remaining_time_for_overall_timeout) + if current_status_check_timeout <= 0: + logger.error(f"Not enough time left for another status check for job {req_id} within overall_timeout.") + return None + + try: + async with aiohttp.ClientSession() as session: + logger.debug(f"Checking status for {req_id} with timeout {current_status_check_timeout}s.") + async with session.get( + f"{status_url}/{req_id}", + headers=headers, + timeout=current_status_check_timeout + ) as response: + if response.status == 200: + logger.info(f"AlphaFold2-Multimer job {req_id} completed (status 200).") + logger.info(f"FINAL 200 OK Response Headers for job {req_id}: {response.headers}") + logger.info(f"FINAL 200 OK Content-Type for job {req_id}: {response.content_type}") + zip_content_bytes = await response.read() + return await _process_nvidia_zip_output( + zip_content=zip_content_bytes, + output_prefix=f"alphafold2_multimer_output_{req_id}", + response_headers=response.headers + ) + elif response.status == 202: + try: + job_status_json = await response.json() + percent_complete = job_status_json.get('percentComplete', 'N/A') + status_message = job_status_json.get('status', 'running') + logger.debug( + f"Job {req_id} status: {status_message} ({percent_complete}% complete). Polling again in {polling_interval}s." + ) + except (aiohttp.ContentTypeError, json.JSONDecodeError): + logger.debug( + f"Job {req_id} still running (202 status, non-JSON/malformed JSON body). Polling again in {polling_interval}s." + ) + await asyncio.sleep(polling_interval) + else: # Handle other error statuses from status GET + logger.error(f"Error checking AlphaFold2-Multimer job status {req_id}: {response.status}") + text = await response.text() + logger.error(f"Response: {text}") + # Log the error, but continue polling unless it's a fatal client error (4xx other than 429) + # or if the server explicitly indicates failure (e.g. 500, or a 200 with error status in body) + # For a 504 like you saw, we might want to retry a few times then give up. + # For now, this will return None on non-200/202, which your test script will catch. + return None + except asyncio.TimeoutError: + logger.warning(f"Client-side timeout ({current_status_check_timeout}s) during status check for job {req_id}. Retrying poll after {polling_interval}s sleep.") + await asyncio.sleep(polling_interval) + except aiohttp.ClientError as e: + logger.error(f"Client error polling job status for {req_id}: {e}. 
Retrying poll after {polling_interval}s.", exc_info=True) + await asyncio.sleep(polling_interval) + except Exception as e: + logger.error(f"Unexpected error polling job status {req_id}: {e}", exc_info=True) + return None \ No newline at end of file diff --git a/environments/hack0/protein_design_env/models/proteinmpnn.py b/environments/hack0/protein_design_env/models/proteinmpnn.py new file mode 100644 index 00000000..179c52ec --- /dev/null +++ b/environments/hack0/protein_design_env/models/proteinmpnn.py @@ -0,0 +1,138 @@ +"""ProteinMPNN API integration for NVIDIA NIM.""" + +import os +import logging +import aiohttp +import json +import asyncio +from typing import Dict, List, Any, Optional, Union +from pathlib import Path + +logger = logging.getLogger(__name__) + +# Default URL +DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/ipd/proteinmpnn/predict" +DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status" + +async def call_proteinmpnn( + input_pdb: str, + api_key: str, + ca_only: bool = False, + use_soluble_model: bool = False, + sampling_temp: List[float] = [0.1], + url: str = DEFAULT_URL, + status_url: str = DEFAULT_STATUS_URL, + polling_interval: int = 10, + timeout: int = 60 +) -> Optional[Dict[str, Any]]: + """ + Call the NVIDIA NIM ProteinMPNN API. + + Args: + input_pdb: PDB structure as a string + api_key: NVIDIA NIM API key + ca_only: Whether to use only Cα atoms + use_soluble_model: Whether to use the soluble model + sampling_temp: List of sampling temperatures + url: API endpoint URL + status_url: Status URL for checking job completion + polling_interval: Seconds between status checks + timeout: Request timeout in seconds + + Returns: + API response or None on failure + """ + # Prepare headers + headers = { + "content-type": "application/json", + "Authorization": f"Bearer {api_key}", + "NVCF-POLL-SECONDS": "300", + } + + # Prepare payload + data = { + "input_pdb": input_pdb, + "ca_only": ca_only, + "use_soluble_model": use_soluble_model, + "sampling_temp": sampling_temp + } + + try: + async with aiohttp.ClientSession() as session: + async with session.post( + url, + json=data, + headers=headers, + timeout=timeout + ) as response: + # Check status code + if response.status == 200: + return await response.json() + elif response.status == 202: + # Asynchronous job, get job ID + req_id = response.headers.get("nvcf-reqid") + if req_id: + logger.info(f"ProteinMPNN job submitted, request ID: {req_id}") + return await _poll_job_status( + req_id=req_id, + headers=headers, + status_url=status_url, + polling_interval=polling_interval, + timeout=timeout + ) + else: + logger.error("No request ID in response headers") + return None + else: + logger.error(f"Error calling ProteinMPNN API: {response.status}") + text = await response.text() + logger.error(f"Response: {text}") + return None + except Exception as e: + logger.error(f"Error calling ProteinMPNN API: {e}") + return None + +async def _poll_job_status( + req_id: str, + headers: Dict[str, str], + status_url: str, + polling_interval: int = 10, + timeout: int = 60 +) -> Optional[Dict[str, Any]]: + """ + Poll the status endpoint until the job completes. 
+ + Args: + req_id: The request ID to check + headers: Request headers + status_url: Status URL for checking job completion + polling_interval: Seconds between status checks + timeout: Request timeout in seconds + + Returns: + The final response or None on failure + """ + while True: + try: + async with aiohttp.ClientSession() as session: + async with session.get( + f"{status_url}/{req_id}", + headers=headers, + timeout=timeout + ) as response: + if response.status == 200: + # Job completed + logger.info(f"ProteinMPNN job {req_id} completed") + return await response.json() + elif response.status == 202: + # Job still running + logger.debug(f"ProteinMPNN job {req_id} still running, polling...") + await asyncio.sleep(polling_interval) + else: + logger.error(f"Error checking ProteinMPNN job status: {response.status}") + text = await response.text() + logger.error(f"Response: {text}") + return None + except Exception as e: + logger.error(f"Error polling ProteinMPNN job status: {e}") + return None \ No newline at end of file diff --git a/environments/hack0/protein_design_env/models/rfdiffusion.py b/environments/hack0/protein_design_env/models/rfdiffusion.py new file mode 100644 index 00000000..38df3f7b --- /dev/null +++ b/environments/hack0/protein_design_env/models/rfdiffusion.py @@ -0,0 +1,142 @@ +"""RFDiffusion API integration for NVIDIA NIM.""" + +import os +import logging +import aiohttp +import json +import asyncio +from typing import Dict, List, Any, Optional, Union +from pathlib import Path + +logger = logging.getLogger(__name__) + +# Default URL +DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/ipd/rfdiffusion/generate" +DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status" + +async def call_rfdiffusion( + input_pdb: str, + api_key: str, + contigs: str = None, + hotspot_res: List[str] = None, + diffusion_steps: int = 15, + url: str = DEFAULT_URL, + status_url: str = DEFAULT_STATUS_URL, + polling_interval: int = 10, + timeout: int = 60 +) -> Optional[Dict[str, Any]]: + """ + Call the NVIDIA NIM RFDiffusion API. + + Args: + input_pdb: PDB structure as a string + api_key: NVIDIA NIM API key + contigs: Contig string (e.g. "A20-60/0 50-100") + hotspot_res: List of hotspot residues (e.g. 
["A50","A51"]) + diffusion_steps: Number of diffusion steps + url: API endpoint URL + status_url: Status URL for checking job completion + polling_interval: Seconds between status checks + timeout: Request timeout in seconds + + Returns: + API response or None on failure + """ + # Prepare headers + headers = { + "content-type": "application/json", + "Authorization": f"Bearer {api_key}", + "NVCF-POLL-SECONDS": "300", + } + + # Prepare payload + data = { + "input_pdb": input_pdb, + "diffusion_steps": diffusion_steps + } + + # Add optional parameters if provided + if contigs: + data["contigs"] = contigs + if hotspot_res: + data["hotspot_res"] = hotspot_res + + try: + async with aiohttp.ClientSession() as session: + async with session.post( + url, + json=data, + headers=headers, + timeout=timeout + ) as response: + # Check status code + if response.status == 200: + return await response.json() + elif response.status == 202: + # Asynchronous job, get job ID + req_id = response.headers.get("nvcf-reqid") + if req_id: + logger.info(f"RFDiffusion job submitted, request ID: {req_id}") + return await _poll_job_status( + req_id=req_id, + headers=headers, + status_url=status_url, + polling_interval=polling_interval, + timeout=timeout + ) + else: + logger.error("No request ID in response headers") + return None + else: + logger.error(f"Error calling RFDiffusion API: {response.status}") + text = await response.text() + logger.error(f"Response: {text}") + return None + except Exception as e: + logger.error(f"Error calling RFDiffusion API: {e}") + return None + +async def _poll_job_status( + req_id: str, + headers: Dict[str, str], + status_url: str, + polling_interval: int = 10, + timeout: int = 60 +) -> Optional[Dict[str, Any]]: + """ + Poll the status endpoint until the job completes. 
+ + Args: + req_id: The request ID to check + headers: Request headers + status_url: Status URL for checking job completion + polling_interval: Seconds between status checks + timeout: Request timeout in seconds + + Returns: + The final response or None on failure + """ + while True: + try: + async with aiohttp.ClientSession() as session: + async with session.get( + f"{status_url}/{req_id}", + headers=headers, + timeout=timeout + ) as response: + if response.status == 200: + # Job completed + logger.info(f"RFDiffusion job {req_id} completed") + return await response.json() + elif response.status == 202: + # Job still running + logger.debug(f"RFDiffusion job {req_id} still running, polling...") + await asyncio.sleep(polling_interval) + else: + logger.error(f"Error checking RFDiffusion job status: {response.status}") + text = await response.text() + logger.error(f"Response: {text}") + return None + except Exception as e: + logger.error(f"Error polling RFDiffusion job status: {e}") + return None \ No newline at end of file diff --git a/environments/hack0/protein_design_env/protein_env.py b/environments/hack0/protein_design_env/protein_env.py index 4a1e51bd..4cd708ea 100644 --- a/environments/hack0/protein_design_env/protein_env.py +++ b/environments/hack0/protein_design_env/protein_env.py @@ -3,9 +3,10 @@ import json import logging import os import random +import re import uuid from pathlib import Path -from typing import Dict, List, Any, Tuple, Optional, Union, TypedDict +from typing import Dict, List, Any, Tuple, Optional, Union, TypedDict, Set import yaml import wandb # Add import for wandb @@ -89,6 +90,85 @@ def load_target_binder_pairs(dataset_name: str, target_col: str, binder_col: str return ds +def get_pdb_chain_details(pdb_content: str, preview_lines: int = 10) -> Tuple[Dict[str, Dict[str, int]], str]: + """ + Parses PDB content to extract detailed information for each chain. + + Returns: + A tuple containing: + - chain_details (Dict[str, Dict[str, int]]): + A dictionary where keys are chain IDs (e.g., "A"). + Each value is another dictionary: + { + "min_residue": int, # Smallest residue number found for this chain + "max_residue": int, # Largest residue number found for this chain + "length": int # Count of unique C-alpha atoms (residues) in this chain + } + - pdb_preview (str): A string preview of the PDB content. 
+ """ + chain_info_temp: Dict[str, Dict[str, Union[Set[int], int]]] = {} # Stores residue numbers and CA count for each chain + atom_lines = [] + header_lines = [] + + # First pass: Collect all residue numbers and CA atoms per chain + for line in pdb_content.splitlines(): + if line.startswith("ATOM"): # Consider only ATOM records for canonical residues + atom_lines.append(line) + chain_id = line[21:22].strip() + if not chain_id: + chain_id = " " # Default for blank chain ID, consider how RFDiffusion handles this + + atom_name = line[12:16].strip() + + try: + residue_num = int(line[22:26].strip()) + + if chain_id not in chain_info_temp: + chain_info_temp[chain_id] = {"residues": set(), "ca_count": 0} + + chain_info_temp[chain_id]["residues"].add(residue_num) + if atom_name == "CA": + chain_info_temp[chain_id]["ca_count"] += 1 + except ValueError: + logger.warning(f"Could not parse residue number from PDB line: {line}") + continue + elif line.startswith("HEADER") or line.startswith("TITLE") or line.startswith("COMPND"): + header_lines.append(line) + + # Second pass: Calculate min, max, and length from collected data + chain_details: Dict[str, Dict[str, int]] = {} + for chain_id, data in chain_info_temp.items(): + if data["residues"]: # Only process if residues were found + min_res = min(data["residues"]) + max_res = max(data["residues"]) + # Length can be defined in two ways: + # 1. max_res - min_res + 1 (if contiguous numbering) + # 2. Count of unique residues (safer for gaps, but AF2 is usually contiguous) + # 3. Count of C-alpha atoms (good proxy for actual modeled residues) + # Let's use ca_count as it reflects actual modeled residues. + # If ca_count is 0 but residues were found (e.g. only HETATMs), this needs thought. + # For now, prioritizing ca_count. + length = data["ca_count"] if data["ca_count"] > 0 else len(data["residues"]) + + chain_details[chain_id] = { + "min_residue": min_res, + "max_residue": max_res, + "length": length + } + else: + logger.warning(f"Chain {chain_id} had no parseable ATOM residue numbers.") + + + # Construct PDB preview + preview_str_parts = header_lines[:min(len(header_lines), preview_lines // 2)] + remaining_preview_lines = preview_lines - len(preview_str_parts) + preview_str_parts.extend(atom_lines[:min(len(atom_lines), remaining_preview_lines)]) + pdb_preview = "\n".join(preview_str_parts) + if len(pdb_content.splitlines()) > preview_lines: + pdb_preview += "\n..." + + return chain_details, pdb_preview + def get_pdb_chain_lengths_and_preview(pdb_content: str, preview_lines: int = 10) -> Tuple[Dict[str, int], str]: chain_lengths = {} current_chain_id = None @@ -130,11 +210,10 @@ def get_pdb_chain_lengths_and_preview(pdb_content: str, preview_lines: int = 10) return chain_lengths, pdb_preview def construct_user_prompt(state: dict) -> str: # state is an item from self.episodes_state - internal_step = state.get("current_internal_step", 0) # Use internal step from our state + internal_step = state.get("current_internal_step", 0) target_sequence = state.get("target_sequence") user_prompt_str = "" - # Base prompt construction (your existing logic) if internal_step == 0: # Step 1: Predict Target Structure (AlphaFold2) user_prompt_str = ( f"The target protein sequence is: {target_sequence}. " @@ -142,52 +221,85 @@ def construct_user_prompt(state: dict) -> str: # state is an item from self.epis "You must provide the 'sequence' argument." 
) elif internal_step == 1: # Step 2: Design Binder Backbone (RFDiffusion) - # Use the stored preview and chain info - target_pdb_preview = state.get("target_pdb_preview", "PDB preview not available.") - chain_info = state.get("target_chain_info", {}) - chain_info_str = ", ".join([f"Chain {cID} (length {length})" for cID, length in chain_info.items()]) - if not chain_info_str: chain_info_str = "Chain information not available." + target_pdb_preview = state.get("target_pdb_preview", "PDB preview not available.") # Can keep preview for general context + + # --- NEW CHAIN INFO FORMATTING --- + chain_details = state.get("target_chain_details", {}) # Get the new detailed info + if chain_details: + chain_info_parts = [] + for chain_id, details in chain_details.items(): + min_r = details.get('min_residue', 'N/A') + max_r = details.get('max_residue', 'N/A') + l = details.get('length', 'N/A') + chain_info_parts.append(f"Chain {chain_id} (Residues: {min_r}-{max_r}, Length: {l} amino acids)") + chain_info_str = "\n- ".join(chain_info_parts) + if chain_info_str: + chain_info_str = "- " + chain_info_str # Add leading bullet for the first item + else: + chain_info_str = "Chain information not available or PDB not yet processed." + # --- END NEW CHAIN INFO FORMATTING --- user_prompt_str = ( - f"The 3D structure of the target protein has been predicted. Target PDB preview:\n{target_pdb_preview}\n" - f"Target chain information: {chain_info_str}.\n" - "Now, design a binder backbone using the 'design_binder_backbone_rfdiffusion' tool. " - "You need to specify 'contigs'. Contigs define segments from the target PDB (e.g., 'A1-100' means residues 1-100 of target chain A) " - "and segments for the new binder (e.g., '/0 50-70' means generate a new chain of length 50 to 70 residues). " - "A full example for a 60-residue binder for Chain A of the target (if Chain A has 100 residues): 'A1-100/0 60'. " - "Ensure any target residue numbers in contigs are within the valid range for the respective chain. " - "Optionally, provide 'hotspot_residues' (e.g., ['A50', 'A52']), ensuring they exist on the target." + f"The 3D structure of the target protein has been predicted.\n" + # Optional: f"Target PDB preview:\n{target_pdb_preview}\n\n" + f"Target Protein Chain Details:\n{chain_info_str}\n\n" # Use the detailed chain info + "Your task is to design a binder backbone using the 'design_binder_backbone_rfdiffusion' tool. " + "You MUST specify 'contigs' for this tool. The 'contigs' string defines segments from the target PDB and segments for the new binder. " + "Examples:\n" + " - To use residues 10 through 100 of target chain A, and then diffuse a 60-residue binder: 'A10-100/0 60'\n" + " - To use chain B from residue 5 to 50, then diffuse a 30-residue binder, then use chain B from residue 60 to 100: 'B5-50/0 30 B60-100'\n" + "You MUST use the chain IDs and residue ranges exactly as provided in the 'Target Protein Chain Details' above. " + "Do not invent chains or residue numbers outside these specified ranges for the target segments. " + "For binder segments (e.g., '/0 60'), specify the desired length (e.g., 60).\n" + "Optionally, provide 'hotspot_residues' (e.g., ['A50', 'A52']), ensuring these residues exist on the target as per the details above." 
) elif internal_step == 2: # Step 3: Design Binder Sequence (ProteinMPNN) - binder_pdb_preview = state.get("binder_pdb_preview", "Binder PDB preview not available.") - binder_chain_info = state.get("binder_chain_info", {}) # Info about the binder backbone itself - binder_info_str = ", ".join([f"Chain {cID} (length {length})" for cID, length in binder_chain_info.items()]) - if not binder_info_str: binder_info_str = "Binder chain information not available." + # Get detailed binder chain information using the get_pdb_chain_details function + binder_pdb_content = state.get("binder_backbone_pdb_content") + if binder_pdb_content: + binder_chain_details, binder_pdb_preview = get_pdb_chain_details(binder_pdb_content) + binder_chain_info_str = "\n- ".join([f"Chain {cID} (Residues: {d.get('min_residue','N/A')}-{d.get('max_residue','N/A')}, Length: {d.get('length','N/A')})" for cID, d in binder_chain_details.items()]) + if binder_chain_info_str: binder_chain_info_str = "- " + binder_chain_info_str + else: + binder_pdb_preview = "Binder PDB preview not available." + binder_chain_info_str = "Binder chain information not available." user_prompt_str = ( f"A binder backbone has been generated. Binder PDB preview:\n{binder_pdb_preview}\n" - f"Binder chain information: {binder_info_str}.\n" + f"Binder chain information:\n{binder_chain_info_str}.\n" "Now, design an optimal amino acid sequence for this binder backbone using the 'design_binder_sequence_proteinmpnn' tool. " "You can optionally specify 'sampling_temp' (e.g., [0.1, 0.2])." ) elif internal_step == 3: # Step 4: Evaluate Complex (AlphaFold2-Multimer) - designed_binder_seq = state.get("designed_binder_sequence", "Not yet available") + designed_binder_seq_data = state.get("designed_binder_sequence") # This is List[str] + + binder_display_str = "Not available" + if isinstance(designed_binder_seq_data, list) and designed_binder_seq_data: + if len(designed_binder_seq_data) == 1: + binder_display_str = designed_binder_seq_data[0] + else: + binder_display_str = f"{len(designed_binder_seq_data)} chains: " + \ + ", ".join([f"Chain {i+1} ({len(s)} aa): {s[:20]}..." + for i, s in enumerate(designed_binder_seq_data)]) + elif isinstance(designed_binder_seq_data, str): # Should not happen with new PMPNN parsing + binder_display_str = designed_binder_seq_data + user_prompt_str = ( - f"A binder sequence has been designed: {designed_binder_seq}. " - f"The original target sequence was: {target_sequence}.\n" # Remind LLM of original target - "Finally, evaluate the binding complex of the original target protein and this designed binder using the " + f"A binder has been designed. Designed binder sequence(s): {binder_display_str}. " + f"The original target sequence was: {target_sequence[:60]}...\n" + "Finally, evaluate the binding complex of the original target protein and ALL chains of this designed binder using the " "'evaluate_binder_complex_alphafold2_multimer' tool. " "You can optionally specify 'relax_prediction' (default is True)." ) else: # Workflow complete or error user_prompt_str = "The protein design workflow is complete. No further actions required by you for this item. If successful, the key metric was the pLDDT of the complex." - # ***** ADD RETRY PREFIX IF APPLICABLE ***** - if state.get("retry_count_this_internal_step", 0) > 0 and internal_step < 4: # For all steps + # Retry logic should remain the same: + if state.get("retry_count_this_internal_step", 0) > 0 and internal_step < 4: retry_prefix = "Your previous attempt at this step was not successful. 
" if state.get("previous_tool_error_message"): retry_prefix += f"Details: {state['previous_tool_error_message']}. " - retry_prefix += "Please review the requirements and try again to correctly use the expected tool.\n\n" + retry_prefix += "Please review the requirements and PDB details carefully and try again to correctly use the expected tool.\n\n" user_prompt_str = retry_prefix + user_prompt_str return user_prompt_str @@ -204,7 +316,7 @@ class BinderBenchConfig(BaseEnvConfig): nim_api_base_url: str = Field("https://health.api.nvidia.com/v1", description="NIM API base URL") api_timeout: int = Field(1800, description="Timeout for NIM API calls") # Increased default polling_interval: int = Field(30, description="Polling interval for NIM jobs") # Increased default - output_dir: str = Field("binder_outputs", description="Directory to save PDBs, etc.") + output_dir: str = Field(default=str(Path(__file__).parent / "outputs"), description="Directory to save PDBs, etc.") debug_protein_design_calls: bool = Field(False, description="Enable debug mode for NIM protein API calls, returning mock data.") max_retries_per_internal_step: int = Field(100, description="Max retries for a failed tool call within a workflow step (0 means no retries).") # Default to 1 retry (2 attempts total) # Dataset specific @@ -349,8 +461,8 @@ class BinderBenchEnv(BaseEnv): # Create a dummy PDB content if file not found to prevent downstream errors, but log severe warning pdb_content = "HEADER DUMMY PDB FOR DEBUG - TARGET.PDB NOT FOUND\nATOM 1 N ALA A 1 0.000 0.000 0.000 1.00 0.00 N\nTER\nEND\n" workflow_state["target_pdb_content"] = pdb_content - chain_lengths, pdb_preview = get_pdb_chain_lengths_and_preview(pdb_content) # Use your helper - workflow_state["target_chain_info"] = chain_lengths + chain_details, pdb_preview = get_pdb_chain_details(pdb_content) # Use the new function + workflow_state["target_chain_details"] = chain_details # Store detailed info workflow_state["target_pdb_preview"] = pdb_preview workflow_state["target_structure_predicted"] = True # Mark as "predicted" for workflow to proceed return {"success": False, "error": f"Debug mode error: {fixed_pdb_path} not found.", "target_pdb_preview": pdb_preview} @@ -360,15 +472,15 @@ class BinderBenchEnv(BaseEnv): pdb_content = f.read() workflow_state["target_pdb_content"] = pdb_content - chain_lengths, pdb_preview = get_pdb_chain_lengths_and_preview(pdb_content) # Use your helper - workflow_state["target_chain_info"] = chain_lengths + chain_details, pdb_preview = get_pdb_chain_details(pdb_content) # Use the new function + workflow_state["target_chain_details"] = chain_details # Store detailed info workflow_state["target_pdb_preview"] = pdb_preview workflow_state["target_structure_predicted"] = True # Optionally, save a copy of this mock PDB to the usual output location for consistency debug_output_pdb_path = self.output_dir / f"target_{item_id}_s{workflow_state['current_internal_step']}_af2_DEBUG.pdb" with open(debug_output_pdb_path, "w") as f: f.write(pdb_content) - logger.info(f"DEBUG MODE: Used fixed PDB from {fixed_pdb_path}. Copied to {debug_output_pdb_path}. Chain info: {chain_lengths}") + logger.info(f"DEBUG MODE: Used fixed PDB from {fixed_pdb_path}. Copied to {debug_output_pdb_path}. 
Chain details: {chain_details}") return {"success": True, "message": "DEBUG MODE: Used fixed PDB for AlphaFold2.", "target_pdb_preview": pdb_preview} except Exception as e: @@ -393,13 +505,13 @@ class BinderBenchEnv(BaseEnv): if api_result and isinstance(api_result, list) and api_result[0]: pdb_content = api_result[0] workflow_state["target_pdb_content"] = pdb_content - chain_lengths, pdb_preview = get_pdb_chain_lengths_and_preview(pdb_content) - workflow_state["target_chain_info"] = chain_lengths + chain_details, pdb_preview = get_pdb_chain_details(pdb_content) # Use the new function + workflow_state["target_chain_details"] = chain_details # Store detailed info workflow_state["target_pdb_preview"] = pdb_preview workflow_state["target_structure_predicted"] = True pdb_path = self.output_dir / f"target_{item_id}_s{workflow_state['current_internal_step']}_af2.pdb" with open(pdb_path, "w") as f: f.write(pdb_content) - logger.info(f"Workflow {item_id}: AlphaFold2 PDB saved to {pdb_path}. Chain info: {chain_lengths}") + logger.info(f"Workflow {item_id}: AlphaFold2 PDB saved to {pdb_path}. Chain details: {chain_details}") return {"success": True, "message": "AlphaFold2 prediction complete.", "target_pdb_preview": pdb_preview} else: logger.error(f"Workflow {item_id}: AlphaFold2 call failed or returned unexpected data: {api_result}") @@ -434,47 +546,125 @@ class BinderBenchEnv(BaseEnv): async def _run_nim_proteinmpnn(self, args: Dict, workflow_state: Dict) -> Dict: binder_pdb = workflow_state.get("binder_backbone_pdb_content") - if not binder_pdb: return {"success": False, "error": "Binder backbone PDB not found for ProteinMPNN."} + if not binder_pdb: + return {"success": False, "error": "Binder backbone PDB not found for ProteinMPNN."} - sampling_temp = args.get("sampling_temp", [0.1]) # Default if not provided + sampling_temp_list = args.get("sampling_temp", [0.1]) item_id = workflow_state["item_id"] + api_result = await call_proteinmpnn( input_pdb=binder_pdb, api_key=self.config.nim_api_key, - sampling_temp=sampling_temp, + sampling_temp=sampling_temp_list, timeout=self.config.api_timeout, polling_interval=self.config.polling_interval - # Add other PMPNN specific params ) - if api_result and "mfasta" in api_result: - fasta_content = api_result["mfasta"] - designed_sequence = "" - for line in fasta_content.splitlines(): - if not line.startswith(">") and line.strip(): - designed_sequence = line.strip() - break - if not designed_sequence: - return {"success": False, "error": "Could not parse sequence from ProteinMPNN FASTA."} - workflow_state["designed_binder_sequence"] = designed_sequence - workflow_state["binder_sequence_designed"] = True - fasta_path = self.output_dir / f"binder_sequence_{item_id}_s{workflow_state['current_internal_step']}_pmpnn.fasta" - with open(fasta_path, "w") as f: f.write(fasta_content) - logger.info(f"Workflow {item_id}: ProteinMPNN FASTA saved to {fasta_path}") - # NO LONGER INCREMENT current_internal_step HERE - collect_trajectories will handle this - return {"success": True, "message": "ProteinMPNN complete.", "designed_binder_sequence": designed_sequence} - else: + if not (api_result and "mfasta" in api_result): logger.error(f"Workflow {item_id}: ProteinMPNN call failed or returned unexpected data: {api_result}") - return {"success": False, "error": "ProteinMPNN failed."} + return {"success": False, "error": "ProteinMPNN call failed or no mfasta in result."} + + fasta_content = api_result["mfasta"] + logger.info(f"CRITICAL_DEBUG: ProteinMPNN raw mfasta output for item 
{item_id}:\n{fasta_content}") + + # --- FASTA Parsing Logic to find best sequence by global_score --- + entries: List[Tuple[float, str, str]] = [] # (global_score, header, sequence_line) + current_header = None + current_sequence_parts: List[str] = [] + + for line_content in fasta_content.splitlines(): + line = line_content.strip() + if not line: continue + + if line.startswith(">"): + if current_header and current_sequence_parts: # Process previous entry + full_sequence_line = "".join(current_sequence_parts) + score_match = re.search(r"global_score=([-\d.]+)", current_header) + global_score = float(score_match.group(1)) if score_match else -float('inf') + entries.append((global_score, current_header, full_sequence_line)) + current_header = line + current_sequence_parts = [] + else: + current_sequence_parts.append(line) + + if current_header and current_sequence_parts: # Process the last entry + full_sequence_line = "".join(current_sequence_parts) + score_match = re.search(r"global_score=([-\d.]+)", current_header) + global_score = float(score_match.group(1)) if score_match else -float('inf') + entries.append((global_score, current_header, full_sequence_line)) + + if not entries: + logger.error(f"Workflow {item_id}: No sequences found in ProteinMPNN mfasta output.") + return {"success": False, "error": "No sequences parsed from ProteinMPNN mfasta."} + + # Sort by global_score (descending) and select the best + entries.sort(key=lambda x: x[0], reverse=True) + best_global_score, best_header, best_full_sequence_line = entries[0] + + logger.info(f"Workflow {item_id}: Best PMPNN sequence chosen (global_score={best_global_score:.4f}) from header: '{best_header}'") + logger.info(f"Workflow {item_id}: Corresponding sequence line: '{best_full_sequence_line}'") + + # Split the selected sequence line by '/' to handle potential chainbreaks + parsed_binder_chains: List[str] = [ + seq_part.strip() for seq_part in best_full_sequence_line.split('/') if seq_part.strip() + ] + + if not parsed_binder_chains: + error_msg = f"Splitting best PMPNN sequence ('{best_full_sequence_line}') by '/' yielded no valid chains." + logger.error(f"Workflow {item_id}: {error_msg}") + return {"success": False, "error": error_msg} + + # Validate each parsed chain (ensure they are valid protein sequences) + for seq_idx, seq_part in enumerate(parsed_binder_chains): + if not (seq_part and seq_part.isalpha() and seq_part.isupper()): + error_msg = f"Parsed binder chain {seq_idx+1} ('{seq_part[:30]}...') contains invalid characters or is empty." + logger.error(f"Workflow {item_id}: {error_msg}") + return {"success": False, "error": error_msg} + + workflow_state["designed_binder_sequence"] = parsed_binder_chains # Store as List[str] + workflow_state["binder_sequence_designed"] = True + + fasta_path = self.output_dir / f"binder_sequence_{item_id}_s{workflow_state['current_internal_step']}_pmpnn.fasta" + with open(fasta_path, "w") as f: f.write(fasta_content) # Save original full FASTA + logger.info(f"Workflow {item_id}: ProteinMPNN FASTA saved to {fasta_path}. Selected binder chains: {parsed_binder_chains}") + + preview = parsed_binder_chains[0][:60] + "..." if len(parsed_binder_chains[0]) > 60 else parsed_binder_chains[0] + if len(parsed_binder_chains) > 1: + preview += f" (+ {len(parsed_binder_chains)-1} more chain(s))" + + return { + "success": True, + "message": f"ProteinMPNN complete. 
Selected best (global_score={best_global_score:.4f}).", + "designed_binder_sequence_list": parsed_binder_chains, + "designed_binder_sequence_preview": preview + } async def _run_nim_af2_multimer(self, args: Dict, workflow_state: Dict) -> Dict: target_seq = workflow_state.get("target_sequence") - binder_seq = workflow_state.get("designed_binder_sequence") - if not target_seq or not binder_seq: - return {"success": False, "error": "Missing target or binder sequence for AlphaFold2-Multimer."} + binder_seq_data = workflow_state.get("designed_binder_sequence") + + if not target_seq: + return {"success": False, "error": "Missing target sequence for AlphaFold2-Multimer."} + + if not binder_seq_data: + return {"success": False, "error": "Missing binder sequence for AlphaFold2-Multimer."} + + # Handle binder_seq_data which could now be either a List[str] or a single string (for backward compatibility) + binder_sequences = [] + if isinstance(binder_seq_data, list): + binder_sequences = binder_seq_data + elif isinstance(binder_seq_data, str): + binder_sequences = [binder_seq_data] # Wrap in list + else: + return {"success": False, "error": f"Unexpected type for binder sequence: {type(binder_seq_data)}"} + + if not binder_sequences: + return {"success": False, "error": "Empty binder sequence list for AlphaFold2-Multimer."} relax = args.get("relax_prediction", True) item_id = workflow_state["item_id"] - logger.info(f"Workflow {item_id}: Running AlphaFold2-Multimer with target (len {len(target_seq)}) and binder (len {len(binder_seq)}).") + # Log all sequences for debugging + total_binder_length = sum(len(seq) for seq in binder_sequences) + logger.info(f"Workflow {item_id}: Running AlphaFold2-Multimer with target (len {len(target_seq)}) and {len(binder_sequences)} binder chain(s) (total len {total_binder_length}).") # Check if in debug mode if self.config.debug_protein_design_calls: @@ -515,12 +705,18 @@ class BinderBenchEnv(BaseEnv): } # Non-debug mode: proceed with actual API call + # Create a list with target sequence as first element, followed by all binder sequences + all_sequences = [target_seq] + binder_sequences + + logger.info(f"Workflow {item_id}: Calling AlphaFold2-Multimer with {len(all_sequences)} sequences: " + f"1 target ({len(target_seq)} aa) + {len(binder_sequences)} binder chains.") + api_result = await call_alphafold2_multimer( - sequences=[target_seq, binder_seq], + sequences=all_sequences, # Pass all sequences as a flat list: [target_seq, binder_seq1, binder_seq2, ...] 
api_key=self.config.nim_api_key, relax_prediction=relax, - timeout=self.config.api_timeout, # Pass configured timeout - polling_interval=self.config.polling_interval # Pass configured polling interval + timeout=self.config.api_timeout, + polling_interval=self.config.polling_interval ) if api_result and api_result.get("structures") and len(api_result["structures"]) > 0: @@ -624,6 +820,7 @@ class BinderBenchEnv(BaseEnv): "target_sequence": target_sequence, "ground_truth_binder_sequence": ground_truth_binder, # Store for final evaluation "target_pdb_content": None, + "target_chain_details": None, # Store detailed chain information "binder_backbone_pdb_content": None, "designed_binder_sequence": None, "complex_pdb_content_path": None, # Path to AF2-Multimer output @@ -758,7 +955,7 @@ class BinderBenchEnv(BaseEnv): workflow_state["retry_count_this_internal_step"] = 0 # Reset for new step workflow_state["previous_tool_error_message"] = None else: # Tool call failed or was incorrect - if workflow_state["current_internal_step"] < 3: # Only retry for steps 0-2 + if workflow_state["current_internal_step"] <= 3: # Retry for steps 0, 1, 2, AND 3 workflow_state["retry_count_this_internal_step"] += 1 if workflow_state["retry_count_this_internal_step"] > self.config.max_retries_per_internal_step: logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: Max retries ({self.config.max_retries_per_internal_step}) reached. Terminating workflow for this item.") @@ -767,8 +964,8 @@ class BinderBenchEnv(BaseEnv): else: logger.info(f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: Failed, attempt {workflow_state['retry_count_this_internal_step']}. Retrying same step.") # Loop continues, construct_user_prompt will use retry info - else: # Failure at step 3 (AF2M) is terminal for the workflow - logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: Failure at critical AF2-Multimer step. Terminating workflow.") + else: # Should never reach here with MAX_INTERNAL_STEPS = 4, but keeping for safety + logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']}: Failure at non-retryable step. Terminating workflow.") workflow_state["workflow_complete_flag"] = True break # Exit the internal while loop @@ -778,24 +975,101 @@ class BinderBenchEnv(BaseEnv): # No break here, loop condition will handle it # After the internal while loop (for process mode) - if not all_turns_data_for_jsonl: return None, [] - last_turn_data = all_turns_data_for_jsonl[-1] - aggregated_messages = [turn_data["messages_this_turn"] for turn_data in all_turns_data_for_jsonl] - aggregated_overrides = [turn_data["overrides_this_turn"] for turn_data in all_turns_data_for_jsonl] - final_reward_for_group = workflow_state.get("cumulative_reward", 0.0) - if workflow_state.get("complex_evaluated") and workflow_state.get("last_tool_success"): - final_reward_for_group = last_turn_data["overrides_this_turn"].get("overall_reward", 0.0) + if not all_turns_data_for_jsonl: + logger.warning(f"Workflow {item_id} in process mode: No turn data collected.") + return None, [] + # --- Start of Fix for jsonl2html --- + html_compatible_messages: List[str] = [] + html_compatible_scores: List[float] = [] + # `overrides_for_jsonl` will store the detailed scoring dict for each turn, + # matching the structure of `html_compatible_messages` and `html_compatible_scores`. 
+ overrides_for_jsonl: List[Dict[str, Any]] = [] + + + for turn_idx, turn_data in enumerate(all_turns_data_for_jsonl): + # Format messages for this turn into a single readable string + turn_str_parts = [f"--- Workflow {item_id} - Turn {turn_idx + 1} ---"] + if turn_data.get("messages_this_turn"): + for msg_obj in turn_data["messages_this_turn"]: + content_str = str(msg_obj.get("content", "[No Content]")) + if msg_obj.get("tool_calls"): + try: + tool_calls_str = json.dumps(msg_obj.get("tool_calls"), indent=2) + content_str += f"\nTool Calls:\n{tool_calls_str}" + except TypeError: # Handle non-serializable content if any + content_str += f"\nTool Calls: [Error serializing tool_calls]" + turn_str_parts.append(f"**{msg_obj.get('role', 'unknown').upper()}**: {content_str}") + else: + turn_str_parts.append("No messages recorded for this turn.") + + html_compatible_messages.append("\n\n".join(turn_str_parts)) + + # Get the score for this specific turn + turn_score = turn_data.get("overrides_this_turn", {}).get("overall_reward", 0.0) + html_compatible_scores.append(turn_score) + + # Add the detailed scoring dictionary for this turn + overrides_for_jsonl.append(turn_data.get("overrides_this_turn", {})) + + + final_workflow_reward = workflow_state.get("cumulative_reward", 0.0) + # If the complex was evaluated successfully, the last turn's reward is the final one. + if workflow_state.get("complex_evaluated") and workflow_state.get("last_tool_success"): + final_workflow_reward = all_turns_data_for_jsonl[-1].get("overrides_this_turn", {}).get("overall_reward", 0.0) + + # For the ScoredDataGroup that will be handled by BaseEnv + # We need tokens and masks for each "message" (turn) if we want BaseEnv to consider it valid + # For simplicity, we can just repeat the last turn's tokens/masks, or use placeholders + # if actual per-turn tokens aren't critical for the JSONL's main purpose (which is visualization via messages/scores). + # Let's create placeholder tokens/masks if full history isn't needed by the trainer for process_mode. + # Or, better, store actual tokens for each turn if available. + + all_tokens_per_turn = [turn_data["tokens_this_turn"] for turn_data in all_turns_data_for_jsonl if turn_data.get("tokens_this_turn")] + all_masks_per_turn = [turn_data["masks_this_turn"] for turn_data in all_turns_data_for_jsonl if turn_data.get("masks_this_turn")] + + # Ensure all_tokens_per_turn and all_masks_per_turn have same length as html_compatible_messages + # If some turns didn't produce tokens (e.g. error), we might need to pad or handle. + # For now, assuming all_turns_data_for_jsonl consistently has tokens/masks for each entry that contributes to html_compatible_messages. + if len(all_tokens_per_turn) != len(html_compatible_messages): + logger.error(f"CRITICAL: Mismatch between tokenized turns ({len(all_tokens_per_turn)}) and HTML messages ({len(html_compatible_messages)}). JSONL will be problematic.") + # Fallback: repeat last turn's tokens if necessary, though this isn't ideal. 
+            if all_turns_data_for_jsonl and all_tokens_per_turn:
+                last_tokens = all_tokens_per_turn[-1]
+                last_masks = all_masks_per_turn[-1]
+                all_tokens_per_turn = [last_tokens] * len(html_compatible_messages)
+                all_masks_per_turn = [last_masks] * len(html_compatible_messages)
+            else: # No token data at all
+                all_tokens_per_turn = [[] for _ in html_compatible_messages]
+                all_masks_per_turn = [[] for _ in html_compatible_messages]
+
+        # This is the ScoredDataGroup that will be written to JSONL by BaseEnv
         process_mode_scored_data = ScoredDataGroup(
-            tokens=[last_turn_data["tokens_this_turn"]],
-            masks=[last_turn_data["masks_this_turn"]],
-            scores=[final_reward_for_group],
-            messages=aggregated_messages if self.config.include_messages else None,
-            overrides=aggregated_overrides,
-            group_overrides={"group_size": 1} # ***** THIS IS THE DEFINITIVE FIX FOR THIS PART *****
+            tokens=all_tokens_per_turn, # List of token lists, one for each turn
+            masks=all_masks_per_turn, # List of mask lists, one for each turn
+
+            # These are critical for jsonl2html
+            messages=html_compatible_messages, # List of strings, one per turn
+            scores=html_compatible_scores, # List of floats, one per turn
+
+            # Store detailed overrides per turn, matching the length of messages/scores
+            overrides=overrides_for_jsonl,
+
+            group_overrides={
+                "group_size": len(html_compatible_messages), # Effective group size is number of turns
+                "item_id": item_id,
+                "is_process_mode_full_workflow": True,
+                "final_score_for_workflow": final_workflow_reward, # Store the overall workflow score here
+                "target_sequence": workflow_state.get("target_sequence", "N/A"),
+                "designed_binder_sequence": workflow_state.get("designed_binder_sequence", "N/A"),
+                "final_plddt": workflow_state.get("af2_multimer_plddt", 0.0)
+            }
         )
-        # Log completed workflow for WandB before adding to metrics
-        await self.add_rollouts_for_wandb(workflow_state)
+        # --- End of Fix for jsonl2html ---
+
+        # Log detailed workflow state to WandB (this call should use workflow_state directly)
+        await self.add_rollouts_for_wandb(data_for_log=workflow_state.copy()) # Keep passing workflow_state for detailed wandb logging
+
         self.completed_episode_metrics.append(workflow_state.copy())
         if item_id in self.episodes_state: del self.episodes_state[item_id]
         return process_mode_scored_data, []
@@ -868,25 +1142,25 @@ class BinderBenchEnv(BaseEnv):
                 workflow_state["retry_count_this_internal_step"] = 0
                 workflow_state["previous_tool_error_message"] = None
             else: # Tool failed or was incorrect
-                if workflow_state["current_internal_step"] < 3: # Only retry for steps 0-2
+                if workflow_state["current_internal_step"] <= 3: # Retry for steps 0, 1, 2, AND 3
                     workflow_state["retry_count_this_internal_step"] += 1
                     if workflow_state["retry_count_this_internal_step"] > self.config.max_retries_per_internal_step:
                         logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']} (Serve Mode): Max retries reached. Terminating.")
                         workflow_state["workflow_complete_flag"] = True
                     # else: it will be added to backlog below to retry
-                else: # Failure at step 3 (AF2M) or other non-retryable step
-                    logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']} (Serve Mode): Failure at critical/non-retryable step. Terminating.")
+                else: # Failure at non-retryable step (should never reach here with MAX_INTERNAL_STEPS = 4)
+                    logger.warning(f"Workflow {item_id}, Step {workflow_state['current_internal_step']} (Serve Mode): Failure at non-retryable step. Terminating.")
Terminating.") workflow_state["workflow_complete_flag"] = True if workflow_state["current_internal_step"] < 4 and not workflow_state.get("workflow_complete_flag"): # Add to backlog if: # 1. Last tool was successful (to move to next step) # OR - # 2. Last tool failed, current step is < 3, and we haven't hit max retries (to retry current step) + # 2. Last tool failed, current step is <= 3, and we haven't hit max retries (to retry current step) should_add_to_backlog = False if workflow_state["last_tool_success"]: should_add_to_backlog = True - elif workflow_state["current_internal_step"] < 3 and \ + elif workflow_state["current_internal_step"] <= 3 and \ workflow_state["retry_count_this_internal_step"] <= self.config.max_retries_per_internal_step: should_add_to_backlog = True @@ -897,10 +1171,16 @@ class BinderBenchEnv(BaseEnv): logger.info(f"Workflow for {item_id} (Serve Mode) not added to backlog and marked complete. Internal step: {workflow_state['current_internal_step']}") if workflow_state.get("workflow_complete_flag"): # If flag was set either by reaching step 4 or by retry logic - # Log completed workflow for WandB before adding to metrics - await self.add_rollouts_for_wandb(workflow_state) - self.completed_episode_metrics.append(workflow_state.copy()) - if item_id in self.episodes_state: del self.episodes_state[item_id] + # For completed workflows in serve mode, use direct logging with workflow_state + # before it gets deleted + if item_id in self.episodes_state: + # Use direct workflow_state logging for maximum detail + await self.add_rollouts_for_wandb(data_for_log=self.episodes_state[item_id].copy()) + self.completed_episode_metrics.append(self.episodes_state[item_id].copy()) + del self.episodes_state[item_id] + # Note: We don't need to call add_rollouts_for_wandb with scored_data_serve here + # BaseEnv.handle_send_to_api will call it automatically with the scored_data_serve + # that we return return scored_data_serve, backlog_items_serve @@ -968,7 +1248,7 @@ class BinderBenchEnv(BaseEnv): logger.info(f"Workflow {workflow_state['item_id']}, Step {internal_step} (AF2-Multimer): pLDDT={plddt:.2f}. Reward: {detailed_scores['overall_reward']:.2f}") else: - detailed_scores["overall_reward"] = -0.5 + detailed_scores["overall_reward"] = 0.0 logger.warning(f"Workflow {workflow_state['item_id']}, Step {internal_step} (AF2-Multimer): Evaluation failed or wrong tool. Reward: -0.5. Last tool success: {last_tool_success}, Called: {called_tool_name}") else: @@ -1046,24 +1326,62 @@ class BinderBenchEnv(BaseEnv): # Given it's populated by collect_trajectories, clearing seems appropriate for periodic eval. self.completed_episode_metrics.clear() - async def add_rollouts_for_wandb(self, workflow_state: Dict): - """Adds a completed workflow summary to the wandb rollout buffer.""" + async def add_rollouts_for_wandb(self, + scored_data_group: ScoredDataGroup = None, # From BaseEnv + item_id: Item = None, # From BaseEnv + data_for_log: Dict = None): # Our custom param for direct workflow_state logging + """Adds a workflow summary to the wandb rollout buffer. + + This method has two modes of operation: + 1. Direct logging with workflow_state (preferred for detailed logging): + - Called from within collect_trajectories with data_for_log=workflow_state.copy() + - This provides maximum detail for logging + + 2. 
+        2. BaseEnv compatibility mode:
+           - Called from BaseEnv.handle_send_to_api with scored_data_group and item_id
+           - Used automatically by the framework
+           - May have less detail if workflow_state was already deleted
+
+        Args:
+            scored_data_group: The ScoredDataGroup containing token, mask, and score data (from BaseEnv)
+            item_id: The item identifier, which is the key to our episodes_state (from BaseEnv)
+            data_for_log: Direct workflow state to log (our custom parameter for direct logging)
+        """
         if not self.config.use_wandb or not hasattr(self, "rollouts_for_wandb"):
-            # Ensure rollouts_for_wandb exists, BaseEnv usually inits it in its __init__
-            # but if not, init here.
+            # Ensure rollouts_for_wandb exists
             if not hasattr(self, "rollouts_for_wandb"):
                 self.rollouts_for_wandb = []

+        # Determine the workflow state to use
+        workflow_state = None
+
+        # Case 1: Direct workflow state provided (most detailed)
+        if data_for_log is not None and isinstance(data_for_log, dict):
+            workflow_state = data_for_log
+            # Extract item_id from data_for_log if needed
+            if item_id is None and "item_id" in workflow_state:
+                item_id = workflow_state["item_id"]
+
+        # Case 2: Try to get workflow_state from episodes_state (if not already deleted)
+        elif item_id is not None and item_id in self.episodes_state:
+            workflow_state = self.episodes_state[item_id]
+
+        # Case 3: No usable state - early return with a debug log (not warning)
+        # This happens in BaseEnv.handle_send_to_api after workflow is already completed
+        if workflow_state is None:
+            # This is expected in BaseEnv's call after workflow_state is deleted, so use debug level
+            logger.debug(f"No workflow_state available for WandB logging (item_id={item_id})")
+            return
+
         # Customize what you want to see in the WandB table for a completed workflow
         # Handle cases where values might be None
-        item_id = workflow_state.get("item_id", "unknown-id")
         target_seq = workflow_state.get("target_sequence", "N/A")
-
+
         # Handle designed_binder which might be None
         designed_binder = workflow_state.get("designed_binder_sequence", "N/A")
         if designed_binder is None:
             designed_binder = "N/A"
-
+
         plddt = workflow_state.get("af2_multimer_plddt", 0.0)
         iptm = workflow_state.get("af2_multimer_iptm", 0.0) # Even if 0, log it
         cumulative_reward = workflow_state.get("cumulative_reward", 0.0)
@@ -1084,15 +1402,20 @@ class BinderBenchEnv(BaseEnv):
         # Safely truncate strings
         target_preview = target_seq[:30] + "..." if isinstance(target_seq, str) and len(target_seq) > 30 else target_seq
-
+
         if designed_binder == "N/A" or designed_binder is None:
             binder_preview = "N/A"
         else:
             binder_preview = designed_binder[:30] + "..." if len(str(designed_binder)) > 30 else designed_binder
-
+
+        # Use item_id from workflow_state if still None
+        if item_id is None:
+            item_id = workflow_state.get("item_id", "unknown-id")
+
+        # Add to rollouts buffer
         self.rollouts_for_wandb.append(
             ( # This tuple structure will be used by create_rollout_table
-                item_id,
+                str(item_id), # Ensure item_id is a string
                 target_preview,
                 binder_preview,
                 f"{plddt:.2f}",
diff --git a/environments/hack0/protein_design_env/utils/__init__.py b/environments/hack0/protein_design_env/utils/__init__.py
new file mode 100644
index 00000000..dc858459
--- /dev/null
+++ b/environments/hack0/protein_design_env/utils/__init__.py
@@ -0,0 +1 @@
+"""Utility functions for the protein design environment."""
\ No newline at end of file
diff --git a/environments/hack0/protein_design_env/utils/api_utils.py b/environments/hack0/protein_design_env/utils/api_utils.py
new file mode 100644
index 00000000..4155fb85
--- /dev/null
+++ b/environments/hack0/protein_design_env/utils/api_utils.py
@@ -0,0 +1,26 @@
+"""API utility functions for the protein design environment."""
+
+import os
+import logging
+from typing import Optional
+from dotenv import load_dotenv
+
+# Load environment variables from .env file
+load_dotenv()
+
+logger = logging.getLogger(__name__)
+
+def load_api_key() -> Optional[str]:
+    """
+    Load the NVIDIA NIM API key from environment variables.
+
+    Returns:
+        The API key from environment variables, or None if not found
+    """
+    api_key = os.environ.get("NVIDIA_NIM_API_KEY")
+    if not api_key:
+        logger.error("NVIDIA_NIM_API_KEY not found in environment variables. "
+                     "Please set it in your .env file.")
+        return None
+
+    return api_key
\ No newline at end of file
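
Usage sketch (not part of the patch): one way load_api_key() could be wired into the environment's configuration at startup, so that a missing NVIDIA_NIM_API_KEY fails fast instead of surfacing later as a NIM request error. The import path, the EnvConfig dataclass, and build_config() are illustrative assumptions rather than code from this diff; only the nim_api_key, api_timeout, and polling_interval field names mirror the self.config attributes used above.

# Illustrative sketch only -- EnvConfig and build_config() are hypothetical helpers, not part of this diff.
from dataclasses import dataclass
from typing import Optional

from protein_design_env.utils.api_utils import load_api_key  # assumed import path for the new module


@dataclass
class EnvConfig:
    """Minimal stand-in for the environment config referenced as self.config above."""
    nim_api_key: Optional[str] = None
    api_timeout: int = 600          # placeholder default, seconds
    polling_interval: int = 10      # placeholder default, seconds


def build_config() -> EnvConfig:
    # load_api_key() returns None (and logs an error) when NVIDIA_NIM_API_KEY is unset.
    api_key = load_api_key()
    if api_key is None:
        raise RuntimeError("NVIDIA_NIM_API_KEY is required; set it in your .env file.")
    return EnvConfig(nim_api_key=api_key)

Failing fast here keeps the later AlphaFold2/AF2-Multimer calls from each having to re-handle a missing key.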