Refactor: move PDB file saving from alphafold2_multimer into tool_executor

This commit is contained in:
based-tachikoma 2025-05-21 21:23:00 -07:00
parent 227e594ebf
commit 02585947e4
2 changed files with 162 additions and 156 deletions

View file

@ -29,7 +29,7 @@ class ToolExecutor:
if not contigs_str: return "Contigs string is empty."
target_segment_pattern = re.compile(r"([A-Za-z0-9])(\d+)-(\d+)")
active_contig_parts = contigs_str.split('/')
active_contig_parts = contigs_str.split('/') # Split by binder definition markers
for part in active_contig_parts:
chain_segments_in_part = part.strip().split(' ')
@ -97,8 +97,9 @@ class ToolExecutor:
if self.debug_protein_design_calls:
logger.warning(f"DEBUG MODE: Bypassing AlphaFold2 API call for workflow {item_id}.")
project_root = Path(__file__).parent.parent.parent.parent
fixed_pdb_path = project_root / "binder_outputs" / "target.pdb"
# Use a relative path within the package
module_dir = Path(__file__).parent
fixed_pdb_path = module_dir / "debug_target.pdb"
if not fixed_pdb_path.exists():
logger.error(f"Debug mode failed: {fixed_pdb_path} not found.")
@ -323,10 +324,6 @@ class ToolExecutor:
async def _run_nim_af2_multimer(self, args: Dict, workflow_state: Dict) -> Dict:
"""
Runs AlphaFold2-Multimer to evaluate the target-binder complex. Returns structured output
with tool_output and state_updates separated.
"""
item_id = workflow_state["item_id"]
current_internal_step = workflow_state["current_internal_step"]
target_seq = workflow_state.get("target_sequence")
@ -349,98 +346,136 @@ class ToolExecutor:
tool_output = {"success": False, "error": error_msg}
return {"tool_output": tool_output, "state_updates": state_updates}
relax = args.get("relax_prediction", True) # Added to use LLM arg
relax = args.get("relax_prediction", True)
if self.debug_protein_design_calls:
self._debug_af2m_call_count += 1
mock_plddt = 87.5 if self._debug_af2m_call_count % 2 == 1 else 45.2
success_message = f"DEBUG MODE: Returning {'high' if mock_plddt > 50 else 'low'}-quality mock results (call #{self._debug_af2m_call_count})"
# In debug mode, ToolExecutor still handles file saving
debug_pdb_filename = f"complex_{item_id}_s{current_internal_step}_af2m_DEBUG_pLDDT{mock_plddt:.2f}.pdb"
debug_pdb_path = self.output_dir / debug_pdb_filename
try:
with open(debug_pdb_path, "w") as f:
f.write(f"REMARK DEBUG PDB FILE for complex. Predicted pLDDT {mock_plddt}\n")
logger.info(f"DEBUG MODE: Saved mock AF2-Multimer PDB to {debug_pdb_path}")
state_updates["complex_pdb_content_path"] = str(debug_pdb_path)
except IOError as e:
logger.error(f"DEBUG MODE: Failed to write mock PDB {debug_pdb_path}: {e}")
# If saving fails, don't set the path, but can still proceed with mock pLDDT
state_updates["complex_pdb_content_path"] = None
mock_pdb_path = self.output_dir / f"mock_complex_{item_id}_s{current_internal_step}_af2m.pdb"
with open(mock_pdb_path, "w") as f:
f.write(f"MOCK PDB FILE for complex. Predicted pLDDT {mock_plddt}\n")
state_updates["complex_pdb_content_path"] = str(mock_pdb_path)
state_updates["af2_multimer_plddt"] = mock_plddt
state_updates["complex_evaluated"] = True
tool_output = {
"success": True, "message": f"{success_message}. Mock pLDDT: {mock_plddt:.2f}",
"plddt": mock_plddt,
"complex_file_path": str(mock_pdb_path)
"complex_file_path": str(debug_pdb_path) if state_updates["complex_pdb_content_path"] else None
}
return {"tool_output": tool_output, "state_updates": state_updates}
output_subdir = self.output_dir / f"alphafold2_multimer_{item_id}_s{current_internal_step}"
logger.info(f"Using output directory for AlphaFold2-Multimer results: {output_subdir}")
# Call AF2-Multimer - no output_dir passed here
api_result = await call_alphafold2_multimer(
sequences=all_input_sequences_for_multimer,
sequences=all_input_sequences_for_multimer,
api_key=self.nim_api_key,
relax_prediction=relax,
timeout=self.api_timeout,
polling_interval=self.polling_interval,
output_dir=output_subdir
timeout=self.api_timeout,
polling_interval=self.polling_interval
)
if isinstance(api_result, dict):
if "success" in api_result and api_result["success"] is False:
error_detail = api_result.get("error", "AF2-Multimer call failed with error.")
# Check for explicit failure from the call_alphafold2_multimer function
if api_result is None or (isinstance(api_result, dict) and api_result.get("success") is False):
error_detail = "AF2-Multimer call failed or returned None."
if isinstance(api_result, dict):
error_detail = api_result.get("error", "AF2-Multimer call failed with unspecified error.")
detail_info = api_result.get("detail", "")
if detail_info:
error_detail += f" Details: {detail_info}"
logger.error(f"Workflow {item_id}: AF2-Multimer call failed: {error_detail}")
tool_output = {"success": False, "error": error_detail}
if detail_info: error_detail += f" Details: {detail_info}"
logger.error(f"Workflow {item_id}: AF2-Multimer call failed: {error_detail}. API Result: {api_result}")
tool_output = {"success": False, "error": error_detail}
state_updates["complex_evaluated"] = False
return {"tool_output": tool_output, "state_updates": state_updates}
# api_result should now be like: {"structures": [{"model_index": ..., "pdb_content": "...", "average_plddt": ...}, ...]}
# or {"success": True, "structures": [...]}
all_structures_info = api_result.get("structures")
if not all_structures_info or not isinstance(all_structures_info, list):
# This case covers if _process_pdb_and_scores_from_api returned success but empty structures
# or if the structure of api_result is unexpected
message = api_result.get("message", "No structures returned from AF2-Multimer process.")
logger.warning(f"Workflow {item_id}: {message}. API Result: {api_result}")
if not all_structures_info and isinstance(all_structures_info, list): # Empty list of structures
tool_output = {"success": True, "message": "AF2-Multimer ran, but no PDB structures were produced by the API.", "plddt": 0.0, "complex_file_path": None}
state_updates["af2_multimer_plddt"] = 0.0
state_updates["complex_evaluated"] = True # Evaluated, but with no result
state_updates["complex_pdb_content_path"] = None
else: # Malformed result
tool_output = {"success": False, "error": "AF2-Multimer returned unexpected data or no structures."}
state_updates["complex_evaluated"] = False
return {"tool_output": tool_output, "state_updates": state_updates}
return {"tool_output": tool_output, "state_updates": state_updates}
if "structures" in api_result and len(api_result["structures"]) > 0:
all_structures_info = api_result["structures"]
best_structure_info = None
highest_plddt = -1.0
best_structure_info = None
highest_plddt = -1.0
for struct_info in all_structures_info:
current_plddt = struct_info.get("average_plddt", 0.0)
if current_plddt > highest_plddt:
highest_plddt = current_plddt
best_structure_info = struct_info
for struct_info in all_structures_info:
current_plddt = struct_info.get("average_plddt", 0.0)
if current_plddt > highest_plddt:
highest_plddt = current_plddt
best_structure_info = struct_info
if best_structure_info is None: # Should not happen if all_structures_info was not empty
logger.error(f"Workflow {item_id}: No valid structure with pLDDT found in AF2-Multimer results.")
tool_output = {"success": False, "error": "No valid structure with pLDDT in AF2-Multimer results."}
state_updates["complex_evaluated"] = False
return {"tool_output": tool_output, "state_updates": state_updates}
if best_structure_info is None:
logger.error(f"Workflow {item_id}: No valid structure with pLDDT found in AF2-Multimer results, though structures were present.")
tool_output = {"success": False, "error": "No valid structure with pLDDT in AF2-Multimer results."}
state_updates["complex_evaluated"] = False
return {"tool_output": tool_output, "state_updates": state_updates}
best_plddt = best_structure_info.get("average_plddt", 0.0)
best_pdb_path = best_structure_info.get("saved_pdb_path")
best_model_idx = best_structure_info.get("model_index", "N/A")
# Now, save the PDB content of the best structure
best_pdb_content = best_structure_info.get("pdb_content")
best_plddt = best_structure_info.get("average_plddt", 0.0) # Should be same as highest_plddt
best_model_idx = best_structure_info.get("model_index", "NA") # Use NA if not found
state_updates["complex_pdb_content_path"] = best_pdb_path
state_updates["af2_multimer_plddt"] = best_plddt
state_updates["complex_evaluated"] = True
if not best_pdb_content:
logger.error(f"Workflow {item_id}: Best AF2-Multimer structure (Model {best_model_idx}, pLDDT {best_plddt:.2f}) found, but PDB content is missing.")
tool_output = {"success": False, "error": f"Best model (pLDDT {best_plddt:.2f}) has no PDB content."}
state_updates["complex_evaluated"] = False # Or True with pLDDT, but no path
state_updates["af2_multimer_plddt"] = best_plddt
return {"tool_output": tool_output, "state_updates": state_updates}
logger.info(f"Workflow {item_id}: AlphaFold2-Multimer complete. Selected best model (Index {best_model_idx}) with pLDDT: {best_plddt:.2f} from {len(all_structures_info)} models. PDB: {best_pdb_path}")
# Construct filename and save
complex_pdb_filename = f"complex_{item_id}_s{current_internal_step}_af2m_model{best_model_idx}_pLDDT{best_plddt:.2f}.pdb"
complex_pdb_path = self.output_dir / complex_pdb_filename
complex_quality_message = f"AlphaFold2-Multimer evaluation complete. Selected best model (Index {best_model_idx}) with pLDDT: {best_plddt:.2f}"
try:
with open(complex_pdb_path, "w", encoding='utf-8') as f:
f.write(best_pdb_content)
logger.info(f"Workflow {item_id}: AlphaFold2-Multimer complete. Saved best model (Index {best_model_idx}) with pLDDT: {best_plddt:.2f} from {len(all_structures_info)} models to {complex_pdb_path}")
state_updates["complex_pdb_content_path"] = str(complex_pdb_path)
state_updates["af2_multimer_plddt"] = best_plddt
state_updates["complex_evaluated"] = True
tool_output = {
"success": True,
"message": complex_quality_message,
"plddt": best_plddt,
"complex_file_path": best_pdb_path,
"selected_model_index": best_model_idx
}
return {"tool_output": tool_output, "state_updates": state_updates}
error_detail = "AF2-Multimer call failed or returned unexpected data format."
if isinstance(api_result, dict) and "error" in api_result:
error_detail = api_result["error"]
logger.error(f"Workflow {item_id}: AF2-Multimer call failed: {error_detail}. Full API Result: {api_result}")
tool_output = {"success": False, "error": error_detail}
state_updates["complex_evaluated"] = False
complex_quality_message = f"AlphaFold2-Multimer evaluation complete. Selected best model (Index {best_model_idx}) with pLDDT: {best_plddt:.2f}"
tool_output = {
"success": True,
"message": complex_quality_message,
"plddt": best_plddt,
"complex_file_path": str(complex_pdb_path),
"selected_model_index": best_model_idx
}
except IOError as e:
logger.error(f"Workflow {item_id}: Failed to save best AF2-Multimer PDB (Model {best_model_idx}, pLDDT {best_plddt:.2f}) to {complex_pdb_path}: {e}")
tool_output = {"success": False, "error": f"Failed to save best complex PDB: {e}"}
# Still record the pLDDT for reward, even if saving failed
state_updates["af2_multimer_plddt"] = best_plddt
state_updates["complex_pdb_content_path"] = None # Path is not valid
state_updates["complex_evaluated"] = True # It was evaluated, saving failed
return {"tool_output": tool_output, "state_updates": state_updates}
@ -469,4 +504,4 @@ class ToolExecutor:
return {
"tool_output": {"success": False, "error": f"Unknown tool name: {tool_name}"},
"state_updates": {}
}
}