mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-27 17:23:08 +00:00
refactor file saving from alphafold2_multimer to tool_executor
This commit is contained in:
parent
227e594ebf
commit
02585947e4
2 changed files with 162 additions and 156 deletions
|
|
@ -4,13 +4,14 @@ import aiohttp
|
|||
import json
|
||||
import asyncio
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from pathlib import Path
|
||||
from pathlib import Path # Keep Path for type hinting if used elsewhere, but not for output_dir here
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_URL = "https://health.api.nvidia.com/v1/biology/deepmind/alphafold2-multimer"
|
||||
DEFAULT_STATUS_URL = "https://health.api.nvidia.com/v1/status"
|
||||
|
||||
# _split_pdb_content remains the same
|
||||
def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
|
||||
"""
|
||||
Splits a string containing concatenated PDB file contents.
|
||||
|
|
@ -23,7 +24,7 @@ def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
|
|||
|
||||
for line in concatenated_pdb_str.splitlines(keepends=True):
|
||||
current_pdb_lines.append(line)
|
||||
if line.startswith("ENDMDL") or line.startswith("END "):
|
||||
if line.startswith("ENDMDL") or line.startswith("END "): # Fixed: added space for "END "
|
||||
pdbs.append("".join(current_pdb_lines).strip())
|
||||
current_pdb_lines = []
|
||||
|
||||
|
|
@ -35,17 +36,8 @@ def _split_pdb_content(concatenated_pdb_str: str) -> List[str]:
|
|||
return [pdb for pdb in pdbs if pdb]
|
||||
|
||||
|
||||
# calculate_plddt_from_pdb_string remains the same
|
||||
def calculate_plddt_from_pdb_string(pdb_string: str) -> Tuple[float, List[float], Dict[str, List[float]]]:
|
||||
"""
|
||||
Calculates the average pLDDT score from a PDB string for C-alpha atoms.
|
||||
Also returns a list of all C-alpha pLDDTs and a dictionary of pLDDTs per chain.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- average_plddt (float): Average pLDDT over all C-alpha atoms.
|
||||
- plddt_scores_per_ca (List[float]): List of pLDDTs for each C-alpha atom.
|
||||
- plddt_scores_per_chain (Dict[str, List[float]]): Dict mapping chain ID to its C-alpha pLDDTs.
|
||||
"""
|
||||
total_plddt = 0.0
|
||||
ca_atom_count = 0
|
||||
plddt_scores_per_ca: List[float] = []
|
||||
|
|
@ -78,46 +70,25 @@ def calculate_plddt_from_pdb_string(pdb_string: str) -> Tuple[float, List[float]
|
|||
return average_plddt, plddt_scores_per_ca, plddt_scores_per_chain
|
||||
|
||||
async def _process_pdb_and_scores_from_api(
|
||||
pdb_contents: List[str],
|
||||
pdb_contents: List[str], # This is the list of PDB strings from the API
|
||||
job_id: str,
|
||||
api_response_json: Optional[Dict[str, Any]] = None,
|
||||
output_dir: Optional[Path] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
api_response_json: Optional[Dict[str, Any]] = None # Not currently used, but kept for consistency
|
||||
) -> Optional[Dict[str, Any]]: # Return structure: {"structures": [...]}
|
||||
"""
|
||||
Processes a list of PDB strings received from the API JSON response.
|
||||
- The API responds with a direct list of PDB strings, not a nested JSON structure
|
||||
- This function saves each PDB to disk and calculates pLDDT scores
|
||||
- The api_response_json parameter is for potential future use if the API adds metadata
|
||||
- The output_dir parameter allows for customizing where files are saved
|
||||
Processes a list of PDB strings received from the API.
|
||||
- Calculates pLDDT scores for each PDB string.
|
||||
- Does NOT save files to disk.
|
||||
- Returns a dictionary containing a list of structures, each with its PDB content and scores.
|
||||
"""
|
||||
if output_dir is None:
|
||||
# Default behavior if no output dir is provided
|
||||
output_dir_name = f"alphafold2_multimer_output_{job_id}_results"
|
||||
output_dir = Path(f"./{output_dir_name}")
|
||||
|
||||
# Make sure the directory exists
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"Saving AlphaFold2-Multimer results for job {job_id} to directory: {output_dir}")
|
||||
|
||||
results: Dict[str, Any] = {
|
||||
"output_directory": str(output_dir),
|
||||
"structures": [],
|
||||
"ptm_score": None,
|
||||
"iptm_score": None,
|
||||
"structures": []
|
||||
}
|
||||
|
||||
if isinstance(api_response_json, dict):
|
||||
logger.info(f"Attempting to extract additional scores from API JSON response for job {job_id}.")
|
||||
results["ptm_score"] = api_response_json.get("ptm")
|
||||
results["iptm_score"] = api_response_json.get("iptm")
|
||||
if results["ptm_score"] is not None or results["iptm_score"] is not None:
|
||||
logger.info(f"Extracted from API JSON: pTM={results['ptm_score']}, ipTM={results['iptm_score']}")
|
||||
else:
|
||||
logger.info(f"No additional API JSON dictionary provided or it's not a dict for job {job_id}; ptm/iptm scores will be None unless found elsewhere.")
|
||||
|
||||
if not pdb_contents or not isinstance(pdb_contents, list) or not all(isinstance(s, str) for s in pdb_contents):
|
||||
logger.warning(f"No valid PDB content strings provided for job {job_id}.")
|
||||
return results # Return with empty structures list
|
||||
# Return a structure indicating failure or empty results, consistent with how call_alphafold2_multimer handles errors
|
||||
return {"success": False, "error": "No valid PDB content strings from API.", "structures": []}
|
||||
|
||||
|
||||
logger.info(f"Processing {len(pdb_contents)} PDB structure(s) for job {job_id}.")
|
||||
|
||||
|
|
@ -127,8 +98,8 @@ async def _process_pdb_and_scores_from_api(
|
|||
continue
|
||||
|
||||
structure_data: Dict[str, Any] = {
|
||||
"model_index": i,
|
||||
"pdb_content": pdb_str
|
||||
"model_index": i, # Keep model_index for identification
|
||||
"pdb_content": pdb_str # Store the actual PDB content
|
||||
}
|
||||
|
||||
avg_plddt, plddts_per_ca_residue, plddts_by_chain = calculate_plddt_from_pdb_string(pdb_str)
|
||||
|
|
@ -144,26 +115,21 @@ async def _process_pdb_and_scores_from_api(
|
|||
else:
|
||||
avg_plddt_per_chain[chain_id] = 0.0
|
||||
structure_data["average_plddt_per_chain"] = avg_plddt_per_chain
|
||||
|
||||
pdb_file_name_stem = f"alphafold2_multimer_output_{job_id}"
|
||||
rank_suffix = f"_model_{i+1}"
|
||||
pdb_file_path = output_dir / f"{pdb_file_name_stem}{rank_suffix}.pdb"
|
||||
|
||||
try:
|
||||
with open(pdb_file_path, "w", encoding='utf-8') as f_pdb:
|
||||
f_pdb.write(pdb_str)
|
||||
structure_data["saved_pdb_path"] = str(pdb_file_path)
|
||||
logger.info(f"Saved PDB model {i+1} for job {job_id} to {pdb_file_path} with overall avg_pLDDT: {avg_plddt:.2f}")
|
||||
except Exception as e_write:
|
||||
logger.error(f"Failed to write PDB file {pdb_file_path} for job {job_id}: {e_write}")
|
||||
structure_data["saved_pdb_path"] = None
|
||||
|
||||
# REMOVED FILE SAVING LOGIC HERE
|
||||
# structure_data["saved_pdb_path"] = ... (NO LONGER SET HERE)
|
||||
|
||||
results["structures"].append(structure_data)
|
||||
|
||||
if results["structures"]:
|
||||
logger.info(f"Successfully processed and calculated pLDDTs for {len(results['structures'])} structures for job {job_id}.")
|
||||
else:
|
||||
logger.warning(f"No structures were processed for job {job_id}.")
|
||||
# Ensure a consistent return structure even if no structures processed
|
||||
return {"success": True, "message": "No PDB structures found in API response to process.", "structures": []}
|
||||
|
||||
return results
|
||||
|
||||
return results # This will contain {"structures": [...list of dicts with pdb_content and scores...]}
|
||||
|
||||
async def call_alphafold2_multimer(
|
||||
sequences: List[str],
|
||||
|
|
@ -177,13 +143,15 @@ async def call_alphafold2_multimer(
|
|||
url: str = DEFAULT_URL,
|
||||
status_url: str = DEFAULT_STATUS_URL,
|
||||
polling_interval: int = 30,
|
||||
timeout: int = 3600,
|
||||
output_dir: Optional[Path] = None
|
||||
timeout: int = 3600
|
||||
# REMOVED output_dir: Optional[Path] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Call the NVIDIA NIM AlphaFold2-Multimer API.
|
||||
The API now returns JSON with a list of PDB strings, which we process to calculate pLDDT scores.
|
||||
Returns a dictionary with processed PDB strings and computed scores.
|
||||
The API returns JSON with a list of PDB strings.
|
||||
This function processes them to calculate pLDDT scores and returns a dictionary
|
||||
containing a list of structures, each with its PDB content and computed scores.
|
||||
File saving is handled by the caller (ToolExecutor).
|
||||
"""
|
||||
headers = {
|
||||
"content-type": "application/json",
|
||||
|
|
@ -213,28 +181,33 @@ async def call_alphafold2_multimer(
|
|||
) as response:
|
||||
if response.status == 200:
|
||||
logger.info("AlphaFold2-Multimer job completed synchronously.")
|
||||
logger.info(f"SYNC Final response headers: {response.headers}")
|
||||
content_type = response.headers.get("Content-Type", "").lower()
|
||||
logger.info(f"SYNC Final response content-type: {content_type}")
|
||||
|
||||
if "application/json" in content_type:
|
||||
api_response_json_payload = await response.json()
|
||||
if not isinstance(api_response_json_payload, list):
|
||||
# Handle case where API might return {"error": ...} directly in a sync 200
|
||||
if isinstance(api_response_json_payload, dict) and "error" in api_response_json_payload:
|
||||
logger.error(f"Sync API call returned error: {api_response_json_payload['error']}")
|
||||
return {"success": False, "error": api_response_json_payload['error'], "detail": api_response_json_payload.get("detail","")}
|
||||
return {"success": False, "error": "Sync JSON response not a list of PDBs as expected."}
|
||||
req_id_sync = response.headers.get("nvcf-reqid", "sync_job") # Get req_id or make one up
|
||||
return await _process_pdb_and_scores_from_api(
|
||||
|
||||
req_id_sync = response.headers.get("nvcf-reqid", "sync_job")
|
||||
return await _process_pdb_and_scores_from_api( # No output_dir
|
||||
pdb_contents=api_response_json_payload,
|
||||
job_id=req_id_sync,
|
||||
api_response_json=None,
|
||||
output_dir=output_dir
|
||||
api_response_json=None # Not used currently
|
||||
)
|
||||
else:
|
||||
return {"success": False, "error": f"Sync response unexpected content type: {content_type}"}
|
||||
err_text = await response.text()
|
||||
logger.error(f"Sync response unexpected content type: {content_type}. Response: {err_text[:500]}")
|
||||
return {"success": False, "error": f"Sync response unexpected content type: {content_type}", "detail": err_text}
|
||||
|
||||
elif response.status == 202:
|
||||
req_id = response.headers.get("nvcf-reqid")
|
||||
if req_id:
|
||||
logger.info(f"AlphaFold2-Multimer job submitted, request ID: {req_id}")
|
||||
return await _poll_job_status(
|
||||
return await _poll_job_status( # No output_dir
|
||||
req_id=req_id,
|
||||
headers=headers,
|
||||
status_url=status_url,
|
||||
|
|
@ -262,12 +235,14 @@ async def _poll_job_status(
|
|||
status_url: str,
|
||||
polling_interval: int = 30,
|
||||
overall_timeout: int = 3600
|
||||
# REMOVED output_dir: Optional[Path] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
per_status_request_timeout = 600
|
||||
logger.info(f"Polling job {req_id}. Individual status check timeout: {per_status_request_timeout}s, Polling interval: {polling_interval}s, Overall timeout: {overall_timeout}s")
|
||||
|
||||
while True:
|
||||
# ... (polling logic for time checks remains the same) ...
|
||||
current_loop_time = asyncio.get_event_loop().time()
|
||||
elapsed_time = current_loop_time - start_time
|
||||
|
||||
|
|
@ -290,37 +265,33 @@ async def _poll_job_status(
|
|||
headers=headers,
|
||||
timeout=current_status_check_timeout
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
if response.status == 200: # Job completed
|
||||
logger.info(f"AlphaFold2-Multimer job {req_id} completed (status 200).")
|
||||
logger.info(f"FINAL 200 OK Response Headers for job {req_id}: {response.headers}")
|
||||
logger.info(f"FINAL 200 OK Content-Type for job {req_id}: {response.content_type}")
|
||||
|
||||
if response.content_type == 'application/json':
|
||||
try:
|
||||
api_response_json_payload = await response.json()
|
||||
logger.debug(f"API JSON Response for job {req_id}: {str(api_response_json_payload)[:500]}...")
|
||||
|
||||
if not isinstance(api_response_json_payload, list):
|
||||
# Handle case where API might return {"error": ...} directly
|
||||
if isinstance(api_response_json_payload, dict) and "error" in api_response_json_payload:
|
||||
logger.error(f"Job {req_id}: API returned error: {api_response_json_payload['error']}")
|
||||
return {"success": False, "error": api_response_json_payload['error'], "detail": api_response_json_payload.get("detail","")}
|
||||
logger.error(f"Job {req_id}: Expected API response to be a list of PDB strings, got {type(api_response_json_payload)}.")
|
||||
return {"success": False, "error": "API response was not a list of PDB strings."}
|
||||
|
||||
return await _process_pdb_and_scores_from_api(
|
||||
return await _process_pdb_and_scores_from_api( # No output_dir
|
||||
pdb_contents=api_response_json_payload,
|
||||
job_id=req_id,
|
||||
api_response_json=None,
|
||||
output_dir=output_dir
|
||||
api_response_json=None # Not used currently
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Job {req_id}: Failed to decode JSON response from API.", exc_info=True)
|
||||
raw_text = await response.text()
|
||||
logger.debug(f"Raw text response: {raw_text[:500]}")
|
||||
return {"success": False, "error": "Failed to decode JSON response."}
|
||||
return {"success": False, "error": "Failed to decode JSON response.", "detail": raw_text[:500]}
|
||||
else:
|
||||
logger.error(f"Job {req_id}: Unexpected content type {response.content_type}. Expected application/json.")
|
||||
raw_text = await response.text()
|
||||
logger.debug(f"Raw text response: {raw_text[:500]}")
|
||||
return {"success": False, "error": f"Unexpected content type: {response.content_type}"}
|
||||
|
||||
logger.error(f"Job {req_id}: Unexpected content type {response.content_type}. Expected application/json. Response: {raw_text[:500]}")
|
||||
return {"success": False, "error": f"Unexpected content type: {response.content_type}", "detail": raw_text}
|
||||
# ... (rest of polling logic: 202, errors, timeouts remains the same) ...
|
||||
elif response.status == 202:
|
||||
try:
|
||||
job_status_json = await response.json()
|
||||
|
|
@ -340,10 +311,10 @@ async def _poll_job_status(
|
|||
return {"success": False, "error": f"Status check failed with HTTP {response.status}", "detail": text}
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Client-side timeout ({current_status_check_timeout}s) during status check for job {req_id}. Retrying poll after {polling_interval}s sleep.")
|
||||
await asyncio.sleep(polling_interval)
|
||||
await asyncio.sleep(polling_interval) # Sleep before next attempt
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"Client error polling job status for {req_id}: {e}. Retrying poll after {polling_interval}s.", exc_info=True)
|
||||
await asyncio.sleep(polling_interval)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error polling job status {req_id}: {e}", exc_info=True)
|
||||
return {"success": False, "error": f"Unexpected polling error: {str(e)}"}
|
||||
return {"success": False, "error": f"Unexpected polling error: {str(e)}"}
|
||||
|
|
@ -29,7 +29,7 @@ class ToolExecutor:
|
|||
if not contigs_str: return "Contigs string is empty."
|
||||
|
||||
target_segment_pattern = re.compile(r"([A-Za-z0-9])(\d+)-(\d+)")
|
||||
active_contig_parts = contigs_str.split('/')
|
||||
active_contig_parts = contigs_str.split('/') # Split by binder definition markers
|
||||
|
||||
for part in active_contig_parts:
|
||||
chain_segments_in_part = part.strip().split(' ')
|
||||
|
|
@ -97,8 +97,9 @@ class ToolExecutor:
|
|||
|
||||
if self.debug_protein_design_calls:
|
||||
logger.warning(f"DEBUG MODE: Bypassing AlphaFold2 API call for workflow {item_id}.")
|
||||
project_root = Path(__file__).parent.parent.parent.parent
|
||||
fixed_pdb_path = project_root / "binder_outputs" / "target.pdb"
|
||||
# Use a relative path within the package
|
||||
module_dir = Path(__file__).parent
|
||||
fixed_pdb_path = module_dir / "debug_target.pdb"
|
||||
|
||||
if not fixed_pdb_path.exists():
|
||||
logger.error(f"Debug mode failed: {fixed_pdb_path} not found.")
|
||||
|
|
@ -323,10 +324,6 @@ class ToolExecutor:
|
|||
|
||||
|
||||
async def _run_nim_af2_multimer(self, args: Dict, workflow_state: Dict) -> Dict:
|
||||
"""
|
||||
Runs AlphaFold2-Multimer to evaluate the target-binder complex. Returns structured output
|
||||
with tool_output and state_updates separated.
|
||||
"""
|
||||
item_id = workflow_state["item_id"]
|
||||
current_internal_step = workflow_state["current_internal_step"]
|
||||
target_seq = workflow_state.get("target_sequence")
|
||||
|
|
@ -349,98 +346,136 @@ class ToolExecutor:
|
|||
tool_output = {"success": False, "error": error_msg}
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
relax = args.get("relax_prediction", True) # Added to use LLM arg
|
||||
relax = args.get("relax_prediction", True)
|
||||
|
||||
if self.debug_protein_design_calls:
|
||||
self._debug_af2m_call_count += 1
|
||||
mock_plddt = 87.5 if self._debug_af2m_call_count % 2 == 1 else 45.2
|
||||
success_message = f"DEBUG MODE: Returning {'high' if mock_plddt > 50 else 'low'}-quality mock results (call #{self._debug_af2m_call_count})"
|
||||
|
||||
# In debug mode, ToolExecutor still handles file saving
|
||||
debug_pdb_filename = f"complex_{item_id}_s{current_internal_step}_af2m_DEBUG_pLDDT{mock_plddt:.2f}.pdb"
|
||||
debug_pdb_path = self.output_dir / debug_pdb_filename
|
||||
try:
|
||||
with open(debug_pdb_path, "w") as f:
|
||||
f.write(f"REMARK DEBUG PDB FILE for complex. Predicted pLDDT {mock_plddt}\n")
|
||||
logger.info(f"DEBUG MODE: Saved mock AF2-Multimer PDB to {debug_pdb_path}")
|
||||
state_updates["complex_pdb_content_path"] = str(debug_pdb_path)
|
||||
except IOError as e:
|
||||
logger.error(f"DEBUG MODE: Failed to write mock PDB {debug_pdb_path}: {e}")
|
||||
# If saving fails, don't set the path, but can still proceed with mock pLDDT
|
||||
state_updates["complex_pdb_content_path"] = None
|
||||
|
||||
mock_pdb_path = self.output_dir / f"mock_complex_{item_id}_s{current_internal_step}_af2m.pdb"
|
||||
with open(mock_pdb_path, "w") as f:
|
||||
f.write(f"MOCK PDB FILE for complex. Predicted pLDDT {mock_plddt}\n")
|
||||
|
||||
state_updates["complex_pdb_content_path"] = str(mock_pdb_path)
|
||||
state_updates["af2_multimer_plddt"] = mock_plddt
|
||||
state_updates["complex_evaluated"] = True
|
||||
|
||||
tool_output = {
|
||||
"success": True, "message": f"{success_message}. Mock pLDDT: {mock_plddt:.2f}",
|
||||
"plddt": mock_plddt,
|
||||
"complex_file_path": str(mock_pdb_path)
|
||||
"complex_file_path": str(debug_pdb_path) if state_updates["complex_pdb_content_path"] else None
|
||||
}
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
output_subdir = self.output_dir / f"alphafold2_multimer_{item_id}_s{current_internal_step}"
|
||||
logger.info(f"Using output directory for AlphaFold2-Multimer results: {output_subdir}")
|
||||
|
||||
# Call AF2-Multimer - no output_dir passed here
|
||||
api_result = await call_alphafold2_multimer(
|
||||
sequences=all_input_sequences_for_multimer,
|
||||
sequences=all_input_sequences_for_multimer,
|
||||
api_key=self.nim_api_key,
|
||||
relax_prediction=relax,
|
||||
timeout=self.api_timeout,
|
||||
polling_interval=self.polling_interval,
|
||||
output_dir=output_subdir
|
||||
timeout=self.api_timeout,
|
||||
polling_interval=self.polling_interval
|
||||
)
|
||||
|
||||
if isinstance(api_result, dict):
|
||||
if "success" in api_result and api_result["success"] is False:
|
||||
error_detail = api_result.get("error", "AF2-Multimer call failed with error.")
|
||||
# Check for explicit failure from the call_alphafold2_multimer function
|
||||
if api_result is None or (isinstance(api_result, dict) and api_result.get("success") is False):
|
||||
error_detail = "AF2-Multimer call failed or returned None."
|
||||
if isinstance(api_result, dict):
|
||||
error_detail = api_result.get("error", "AF2-Multimer call failed with unspecified error.")
|
||||
detail_info = api_result.get("detail", "")
|
||||
if detail_info:
|
||||
error_detail += f" Details: {detail_info}"
|
||||
logger.error(f"Workflow {item_id}: AF2-Multimer call failed: {error_detail}")
|
||||
tool_output = {"success": False, "error": error_detail}
|
||||
if detail_info: error_detail += f" Details: {detail_info}"
|
||||
|
||||
logger.error(f"Workflow {item_id}: AF2-Multimer call failed: {error_detail}. API Result: {api_result}")
|
||||
tool_output = {"success": False, "error": error_detail}
|
||||
state_updates["complex_evaluated"] = False
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
# api_result should now be like: {"structures": [{"model_index": ..., "pdb_content": "...", "average_plddt": ...}, ...]}
|
||||
# or {"success": True, "structures": [...]}
|
||||
|
||||
all_structures_info = api_result.get("structures")
|
||||
if not all_structures_info or not isinstance(all_structures_info, list):
|
||||
# This case covers if _process_pdb_and_scores_from_api returned success but empty structures
|
||||
# or if the structure of api_result is unexpected
|
||||
message = api_result.get("message", "No structures returned from AF2-Multimer process.")
|
||||
logger.warning(f"Workflow {item_id}: {message}. API Result: {api_result}")
|
||||
if not all_structures_info and isinstance(all_structures_info, list): # Empty list of structures
|
||||
tool_output = {"success": True, "message": "AF2-Multimer ran, but no PDB structures were produced by the API.", "plddt": 0.0, "complex_file_path": None}
|
||||
state_updates["af2_multimer_plddt"] = 0.0
|
||||
state_updates["complex_evaluated"] = True # Evaluated, but with no result
|
||||
state_updates["complex_pdb_content_path"] = None
|
||||
else: # Malformed result
|
||||
tool_output = {"success": False, "error": "AF2-Multimer returned unexpected data or no structures."}
|
||||
state_updates["complex_evaluated"] = False
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
if "structures" in api_result and len(api_result["structures"]) > 0:
|
||||
all_structures_info = api_result["structures"]
|
||||
|
||||
best_structure_info = None
|
||||
highest_plddt = -1.0
|
||||
best_structure_info = None
|
||||
highest_plddt = -1.0
|
||||
|
||||
for struct_info in all_structures_info:
|
||||
current_plddt = struct_info.get("average_plddt", 0.0)
|
||||
if current_plddt > highest_plddt:
|
||||
highest_plddt = current_plddt
|
||||
best_structure_info = struct_info
|
||||
for struct_info in all_structures_info:
|
||||
current_plddt = struct_info.get("average_plddt", 0.0)
|
||||
if current_plddt > highest_plddt:
|
||||
highest_plddt = current_plddt
|
||||
best_structure_info = struct_info
|
||||
|
||||
if best_structure_info is None: # Should not happen if all_structures_info was not empty
|
||||
logger.error(f"Workflow {item_id}: No valid structure with pLDDT found in AF2-Multimer results.")
|
||||
tool_output = {"success": False, "error": "No valid structure with pLDDT in AF2-Multimer results."}
|
||||
state_updates["complex_evaluated"] = False
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
if best_structure_info is None:
|
||||
logger.error(f"Workflow {item_id}: No valid structure with pLDDT found in AF2-Multimer results, though structures were present.")
|
||||
tool_output = {"success": False, "error": "No valid structure with pLDDT in AF2-Multimer results."}
|
||||
state_updates["complex_evaluated"] = False
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
best_plddt = best_structure_info.get("average_plddt", 0.0)
|
||||
best_pdb_path = best_structure_info.get("saved_pdb_path")
|
||||
best_model_idx = best_structure_info.get("model_index", "N/A")
|
||||
# Now, save the PDB content of the best structure
|
||||
best_pdb_content = best_structure_info.get("pdb_content")
|
||||
best_plddt = best_structure_info.get("average_plddt", 0.0) # Should be same as highest_plddt
|
||||
best_model_idx = best_structure_info.get("model_index", "NA") # Use NA if not found
|
||||
|
||||
state_updates["complex_pdb_content_path"] = best_pdb_path
|
||||
state_updates["af2_multimer_plddt"] = best_plddt
|
||||
state_updates["complex_evaluated"] = True
|
||||
if not best_pdb_content:
|
||||
logger.error(f"Workflow {item_id}: Best AF2-Multimer structure (Model {best_model_idx}, pLDDT {best_plddt:.2f}) found, but PDB content is missing.")
|
||||
tool_output = {"success": False, "error": f"Best model (pLDDT {best_plddt:.2f}) has no PDB content."}
|
||||
state_updates["complex_evaluated"] = False # Or True with pLDDT, but no path
|
||||
state_updates["af2_multimer_plddt"] = best_plddt
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
logger.info(f"Workflow {item_id}: AlphaFold2-Multimer complete. Selected best model (Index {best_model_idx}) with pLDDT: {best_plddt:.2f} from {len(all_structures_info)} models. PDB: {best_pdb_path}")
|
||||
# Construct filename and save
|
||||
complex_pdb_filename = f"complex_{item_id}_s{current_internal_step}_af2m_model{best_model_idx}_pLDDT{best_plddt:.2f}.pdb"
|
||||
complex_pdb_path = self.output_dir / complex_pdb_filename
|
||||
|
||||
complex_quality_message = f"AlphaFold2-Multimer evaluation complete. Selected best model (Index {best_model_idx}) with pLDDT: {best_plddt:.2f}"
|
||||
try:
|
||||
with open(complex_pdb_path, "w", encoding='utf-8') as f:
|
||||
f.write(best_pdb_content)
|
||||
logger.info(f"Workflow {item_id}: AlphaFold2-Multimer complete. Saved best model (Index {best_model_idx}) with pLDDT: {best_plddt:.2f} from {len(all_structures_info)} models to {complex_pdb_path}")
|
||||
|
||||
state_updates["complex_pdb_content_path"] = str(complex_pdb_path)
|
||||
state_updates["af2_multimer_plddt"] = best_plddt
|
||||
state_updates["complex_evaluated"] = True
|
||||
|
||||
tool_output = {
|
||||
"success": True,
|
||||
"message": complex_quality_message,
|
||||
"plddt": best_plddt,
|
||||
"complex_file_path": best_pdb_path,
|
||||
"selected_model_index": best_model_idx
|
||||
}
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
error_detail = "AF2-Multimer call failed or returned unexpected data format."
|
||||
if isinstance(api_result, dict) and "error" in api_result:
|
||||
error_detail = api_result["error"]
|
||||
|
||||
logger.error(f"Workflow {item_id}: AF2-Multimer call failed: {error_detail}. Full API Result: {api_result}")
|
||||
tool_output = {"success": False, "error": error_detail}
|
||||
state_updates["complex_evaluated"] = False
|
||||
complex_quality_message = f"AlphaFold2-Multimer evaluation complete. Selected best model (Index {best_model_idx}) with pLDDT: {best_plddt:.2f}"
|
||||
|
||||
tool_output = {
|
||||
"success": True,
|
||||
"message": complex_quality_message,
|
||||
"plddt": best_plddt,
|
||||
"complex_file_path": str(complex_pdb_path),
|
||||
"selected_model_index": best_model_idx
|
||||
}
|
||||
except IOError as e:
|
||||
logger.error(f"Workflow {item_id}: Failed to save best AF2-Multimer PDB (Model {best_model_idx}, pLDDT {best_plddt:.2f}) to {complex_pdb_path}: {e}")
|
||||
tool_output = {"success": False, "error": f"Failed to save best complex PDB: {e}"}
|
||||
# Still record the pLDDT for reward, even if saving failed
|
||||
state_updates["af2_multimer_plddt"] = best_plddt
|
||||
state_updates["complex_pdb_content_path"] = None # Path is not valid
|
||||
state_updates["complex_evaluated"] = True # It was evaluated, saving failed
|
||||
|
||||
return {"tool_output": tool_output, "state_updates": state_updates}
|
||||
|
||||
|
||||
|
|
@ -469,4 +504,4 @@ class ToolExecutor:
|
|||
return {
|
||||
"tool_output": {"success": False, "error": f"Unknown tool name: {tool_name}"},
|
||||
"state_updates": {}
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue