mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
Ethereum Virtual Machine Text to Transaction Environment (#187)
* EVM-text_to_transaction * update structure * Update README --------- Co-authored-by: Jeremy Melvin <jeremy@openblocklabs.com>
This commit is contained in:
parent
d0a253e1b5
commit
3bed7c64b9
8 changed files with 2135 additions and 0 deletions
779
environments/community/ethereum_virtual_machine/evm_server.py
Normal file
779
environments/community/ethereum_virtual_machine/evm_server.py
Normal file
|
|
@ -0,0 +1,779 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
EVM Environment for Atropos: Ethereum Virtual Machine Transaction Agent Training
|
||||
|
||||
This environment trains language models to generate and execute profitable Ethereum transactions
|
||||
using Anvil (Foundry's local blockchain simulation).
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from anvil import AnvilBackend, AnvilConfig
|
||||
from evm_config import EVMEnvConfig
|
||||
from openai import OpenAI
|
||||
from utils import cleanup_blockchain, cleanup_manager, setup_evm_error_message
|
||||
|
||||
from atroposlib.envs.base import BaseEnv, ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from atroposlib.type_definitions import Item
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
# Add logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# System prompt for EVM transaction agent
|
||||
system_prompt = (
|
||||
"You are a deep thinking AI, you may use extremely long chains of thought "
|
||||
"to deeply consider the problem and deliberate with yourself via systematic "
|
||||
"reasoning processes to help come to a correct solution prior to answering. "
|
||||
"You should enclose your thoughts and internal monologue inside <think> </think> "
|
||||
"tags, and then provide your solution or response to the problem.\n\n"
|
||||
)
|
||||
|
||||
system_prompt += """You are allowed to use a maximum of 2048 tokens. Please strive to use less.
|
||||
|
||||
You are here to assist a user execute transfers of both ETH and ERC-20 tokens as requested.
|
||||
Your job is to generate correct Ethereum transaction data for the requested action.
|
||||
|
||||
IMPORTANT: After your thinking, your response must include a valid JSON transaction object:
|
||||
{"to": "0x...", "value": "amount_in_wei", "data": "0x..."}
|
||||
|
||||
- 'to': The recipient address (contract or EOA)
|
||||
- 'value': Amount of ETH to send in wei (string)
|
||||
- 'data': Transaction data
|
||||
|
||||
If you do not provide a valid JSON transaction object, your submission will be ignored and you \
|
||||
will receive a score of -1.0.
|
||||
|
||||
Example 1:
|
||||
{
|
||||
"to": "0xe688b84b23f322a994A53dbF8E15FA82CDB71127",
|
||||
"value": "0.01",
|
||||
"data": "0x"
|
||||
}
|
||||
|
||||
Example 2:
|
||||
{
|
||||
"to": "0xEA29e9da69317d80075fBfc836E843C6d65971F5",
|
||||
"value": "0x",
|
||||
"data": "0xa9059cbb000000000000000000000000ea29e9da69317d80075fbfc836e843c6d65971f50000000000000000000000000000000000000000000000000000000005f5e100" # noqa: E501
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
class EVMEnv(BaseEnv):
|
||||
"""EVM Transaction Environment for training agents to interact with Ethereum"""
|
||||
|
||||
name = "evm_agent"
|
||||
env_config_cls = EVMEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: EVMEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm=True,
|
||||
testing=False,
|
||||
):
|
||||
"""Initialize the EVM environment"""
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
|
||||
# Set up minimal logging - only for essential operations
|
||||
self.logger = logging.getLogger(f"{self.__class__.__name__}")
|
||||
self.logger.setLevel(logging.WARNING) # Only warnings and errors
|
||||
self.logger.propagate = False
|
||||
if not self.logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter("%(message)s") # Clean format
|
||||
handler.setFormatter(formatter)
|
||||
self.logger.addHandler(handler)
|
||||
|
||||
# Suppress base environment logs
|
||||
if config.suppress_base_env_logs:
|
||||
base_logger = logging.getLogger("atroposlib.envs.base")
|
||||
base_logger.setLevel(logging.WARNING)
|
||||
|
||||
# Load Anvil configuration
|
||||
self.anvil_config = AnvilConfig(config.anvil_config_path)
|
||||
|
||||
# Initialize blockchain handler
|
||||
self.blockchain = AnvilBackend(self.anvil_config)
|
||||
|
||||
# Performance tracking for adaptive question selection
|
||||
self.question_performance = {qtype: [] for qtype in config.question_types}
|
||||
self.current_question_type = None
|
||||
|
||||
# Store current prompt data for scoring
|
||||
self.current_prompt_data = None
|
||||
|
||||
# Register cleanup with the global cleanup manager
|
||||
cleanup_manager.register_cleanup(cleanup_blockchain, self.blockchain)
|
||||
|
||||
async def setup(self):
|
||||
"""Setup the EVM environment and start Anvil"""
|
||||
try:
|
||||
print("Starting Anvil blockchain simulation...")
|
||||
self.blockchain.start()
|
||||
self.blockchain.setup_wallet()
|
||||
print("EVM environment setup completed successfully.")
|
||||
except Exception as e:
|
||||
error_message = setup_evm_error_message(self.anvil_config, e)
|
||||
print(error_message)
|
||||
|
||||
# Cleanup and exit
|
||||
cleanup_blockchain(self.blockchain)
|
||||
sys.exit(1)
|
||||
|
||||
async def get_next_item(self) -> Optional[Item]:
|
||||
"""Generate the next transaction challenge for the agent"""
|
||||
try:
|
||||
# Select question type based on performance (exploration vs exploitation)
|
||||
question_type = self._select_question_type()
|
||||
self.current_question_type = question_type
|
||||
|
||||
# Generate question prompt and get structured data
|
||||
prompt_text, prompt_data = await self._generate_question_prompt(
|
||||
question_type
|
||||
)
|
||||
|
||||
# Store the prompt data for scoring
|
||||
self.current_prompt_data = prompt_data
|
||||
|
||||
# Display Generated Input
|
||||
self.logger.debug("\n=== Generated Input ===")
|
||||
self.logger.debug(prompt_text)
|
||||
self.logger.debug("=" * 50)
|
||||
|
||||
prompt = tuple(
|
||||
[frozenset({"role": "user", "content": prompt_text}.items())]
|
||||
)
|
||||
|
||||
return (prompt, None, None)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in get_next_item: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
def _select_question_type(self) -> str:
|
||||
"""Select question type using weakness-targeting strategy with 80/20 ratio"""
|
||||
# If no performance data yet, select randomly
|
||||
if not any(self.question_performance.values()):
|
||||
return random.choice(self.config.question_types)
|
||||
|
||||
# Calculate average scores for each question type
|
||||
avg_scores = {}
|
||||
for qtype, scores in self.question_performance.items():
|
||||
if scores:
|
||||
avg_scores[qtype] = sum(scores) / len(scores)
|
||||
else:
|
||||
avg_scores[qtype] = 0.0 # Prioritize untested question types
|
||||
|
||||
# Split into weak and strong areas based on configurable performance threshold
|
||||
weak_threshold = self.config.weak_performance_threshold
|
||||
|
||||
weak_qtypes = [
|
||||
qtype for qtype, score in avg_scores.items() if score < weak_threshold
|
||||
]
|
||||
strong_qtypes = [
|
||||
qtype for qtype, score in avg_scores.items() if score >= weak_threshold
|
||||
]
|
||||
|
||||
# Configurable focus on weak areas vs strong areas for mastery maintenance
|
||||
if random.random() < self.config.weak_area_focus_ratio and weak_qtypes:
|
||||
selected_type = random.choice(weak_qtypes)
|
||||
elif strong_qtypes:
|
||||
selected_type = random.choice(strong_qtypes)
|
||||
else:
|
||||
selected_type = random.choice(list(avg_scores.keys()))
|
||||
|
||||
return selected_type
|
||||
|
||||
async def _generate_question_prompt(
|
||||
self, question_type: str
|
||||
) -> Tuple[str, Optional[Dict[str, Any]]]:
|
||||
"""Generate a dynamic question prompt using LLM based on the question type"""
|
||||
|
||||
# Initialize OpenAI client
|
||||
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
|
||||
# Create prompt for LLM to generate a request
|
||||
llm_prompt = f"""You are generating a natural language transaction request for an Ethereum blockchain agent.
|
||||
|
||||
TRANSACTION TYPE: "{question_type}"
|
||||
|
||||
CONTEXT:
|
||||
- Wallet Address: {self.anvil_config.funding.custom_wallet}
|
||||
- Current Balances: {json.dumps(self.blockchain.get_wallet_balances(), indent=2)}
|
||||
|
||||
TASK:
|
||||
Generate a realistic, conversational transaction request that:
|
||||
1. Matches the specified transaction type exactly
|
||||
2. Does not use more than current wallet balances and typically would be small transfers or possibly larger,
|
||||
like how a real person may use their assets
|
||||
3. Includes all necessary details (token, amount, destination address)
|
||||
4. Sounds like how a real user would naturally request a transaction
|
||||
5. Varies in tone and style (casual, formal, urgent, etc.)
|
||||
|
||||
REQUIREMENTS:
|
||||
- Use realistic destination addresses (not placeholder text like
|
||||
"0x1234567890123456789012345678901234567890")
|
||||
- Does not specify amounts larger than 50% of the current balance
|
||||
- Make the request executable
|
||||
|
||||
Generate ONE natural language request that matches the transaction type "{question_type}".
|
||||
Respond with a JSON object with the following fields:
|
||||
- question_type: The type of transaction to generate
|
||||
- request: The natural language request text
|
||||
- destination_address: The destination address
|
||||
- transfer_token: The token to transfer
|
||||
- transfer_amount: The amount to transfer
|
||||
|
||||
Examples:
|
||||
1. ETH transfer
|
||||
{{
|
||||
"question_type": "ETH transfer",
|
||||
"request": "yo, can i send 0.01 ETH to my buddy jasper? His address is jasper.eth"
|
||||
"destination_address": "jasper.eth"
|
||||
"transfer_token": "ETH"
|
||||
"transfer_amount": "0.01"
|
||||
}}
|
||||
2. ERC-20 transfer using 18 decimal token
|
||||
{{
|
||||
"question_type": "ERC-20 transfer using 18 decimal token",
|
||||
"request": "Send 100 CRV to 0xe688b84b23f322a994A53dbF8E15FA82CDB71127"
|
||||
"destination_address": "0xe688b84b23f322a994A53dbF8E15FA82CDB71127"
|
||||
"transfer_token": "CRV"
|
||||
"transfer_amount": "100"
|
||||
}}
|
||||
3. ERC-20 transfer using a non-18 decimal token (e.g. USDT)
|
||||
{{
|
||||
"question_type": "ERC-20 transfer using a non-18 decimal token",
|
||||
"request": "give 100 tether to 0xea29e9da69317d80075fbfc836e843c6d65971f5"
|
||||
"destination_address": "0xea29e9da69317d80075fbfc836e843c6d65971f5"
|
||||
"transfer_token": "USDT"
|
||||
"transfer_amount": "100"
|
||||
}}
|
||||
"""
|
||||
|
||||
try:
|
||||
# Generate multiple responses in a single call for efficiency
|
||||
response = client.chat.completions.create(
|
||||
model=self.config.question_generation_model,
|
||||
messages=[{"role": "user", "content": llm_prompt}],
|
||||
temperature=self.config.question_generation_temperature,
|
||||
max_tokens=self.config.question_generation_max_tokens,
|
||||
n=self.config.question_generation_n,
|
||||
)
|
||||
|
||||
# Try each response until we find a valid one
|
||||
for i, choice in enumerate(response.choices):
|
||||
generated_content = choice.message.content.strip()
|
||||
|
||||
# Extract JSON from response using generic function
|
||||
prompt_data = self._extract_json_from_response(
|
||||
generated_content, ["question_type", "request"], "prompt"
|
||||
)
|
||||
|
||||
# Validate required fields
|
||||
if prompt_data and self._validate_prompt_data(
|
||||
prompt_data, question_type
|
||||
):
|
||||
return prompt_data["request"], prompt_data
|
||||
|
||||
# All choices failed, use fallback
|
||||
fallback_data = {
|
||||
"question_type": question_type,
|
||||
"request": "Transfer 0.01 ETH to 0x0000000000000000000000000000000000000000",
|
||||
"destination_address": "0x0000000000000000000000000000000000000000",
|
||||
"transfer_token": "ETH",
|
||||
"transfer_amount": "0.01",
|
||||
}
|
||||
return fallback_data["request"], fallback_data
|
||||
|
||||
except Exception:
|
||||
fallback_data = {
|
||||
"question_type": question_type,
|
||||
"request": "Transfer 0.01 ETH to 0x0000000000000000000000000000000000000000",
|
||||
"destination_address": "0x0000000000000000000000000000000000000000",
|
||||
"transfer_token": "ETH",
|
||||
"transfer_amount": "0.01",
|
||||
}
|
||||
return fallback_data["request"], fallback_data
|
||||
|
||||
def _extract_json_from_response(
|
||||
self, response: str, required_keys: List[str], json_type: str = "JSON"
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Generic JSON extraction from LLM response, handling thinking tags"""
|
||||
if not isinstance(response, str):
|
||||
return None
|
||||
|
||||
# First, try to extract content after thinking tags (following SWE pattern)
|
||||
content_after_think = response
|
||||
think_end_match = re.search(r"</think>", response, re.IGNORECASE)
|
||||
if think_end_match:
|
||||
content_after_think = response[think_end_match.end() :].strip()
|
||||
|
||||
# Create patterns based on required keys
|
||||
if len(required_keys) >= 2:
|
||||
key1, key2 = required_keys[0], required_keys[1]
|
||||
json_patterns = [
|
||||
rf'\{{[^{{}}]*"{key1}"[^{{}}]*"{key2}"[^{{}}]*\}}', # Simple pattern with first two keys
|
||||
rf'\{{.*?"{key1}".*?"{key2}".*?\}}', # Flexible pattern with first two keys
|
||||
r"\{.*?\}", # Any JSON object
|
||||
]
|
||||
else:
|
||||
json_patterns = [r"\{.*?\}"] # Fallback to any JSON
|
||||
|
||||
for pattern in json_patterns:
|
||||
matches = re.findall(pattern, content_after_think, re.DOTALL)
|
||||
for match in matches:
|
||||
try:
|
||||
# Clean up the JSON string
|
||||
json_str = match.strip()
|
||||
|
||||
# Parse the JSON
|
||||
obj = json.loads(json_str)
|
||||
|
||||
# Verify it has the expected structure
|
||||
if isinstance(obj, dict) and required_keys[0] in obj:
|
||||
return obj
|
||||
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
def _extract_transaction_json(self, response: str) -> Optional[Dict[str, Any]]:
|
||||
"""Extract transaction JSON from LLM response"""
|
||||
return self._extract_json_from_response(
|
||||
response, ["to", "value"], "transaction"
|
||||
)
|
||||
|
||||
def _validate_prompt_data(
|
||||
self, prompt_data: Dict[str, Any], expected_question_type: str
|
||||
) -> bool:
|
||||
"""Validate that prompt data has all required fields and correct question type"""
|
||||
required_fields = [
|
||||
"question_type",
|
||||
"request",
|
||||
"destination_address",
|
||||
"transfer_token",
|
||||
"transfer_amount",
|
||||
]
|
||||
|
||||
# Check all required fields are present
|
||||
if not all(field in prompt_data for field in required_fields):
|
||||
return False
|
||||
|
||||
# Check question type matches what we requested
|
||||
if prompt_data["question_type"] != expected_question_type:
|
||||
return False
|
||||
|
||||
# Check that fields are not empty
|
||||
for field in required_fields:
|
||||
if not prompt_data[field] or str(prompt_data[field]).strip() == "":
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def collect_trajectories(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
|
||||
"""Collect trajectories by having the agent generate transactions"""
|
||||
to_score = []
|
||||
to_backlog = []
|
||||
|
||||
system_msg = {
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
}
|
||||
|
||||
user_msg = {"role": "user", "content": dict(item[0][0])["content"]}
|
||||
|
||||
messages = [system_msg, user_msg]
|
||||
|
||||
try:
|
||||
# Use proper Atropos framework pattern like humor generation
|
||||
chat_completions = await self.server.chat_completion(
|
||||
messages=messages,
|
||||
n=self.config.group_size,
|
||||
max_tokens=2048,
|
||||
)
|
||||
|
||||
# Store completions for output saving
|
||||
self.last_completions = []
|
||||
|
||||
for i, choice in enumerate(chat_completions.choices):
|
||||
# Store the completion
|
||||
self.last_completions.append(choice.message.content)
|
||||
|
||||
# Display Generated Output
|
||||
self.logger.debug(f"\n=== Generated Output {i+1} ===")
|
||||
self.logger.debug(choice.message.content)
|
||||
self.logger.debug("=" * 50)
|
||||
|
||||
history = [
|
||||
{"role": "system", "content": system_msg["content"]},
|
||||
{"role": "user", "content": user_msg["content"]},
|
||||
{"role": "assistant", "content": choice.message.content},
|
||||
]
|
||||
to_score.append((history, item[1], None))
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in collect_trajectories: {e}")
|
||||
traceback.print_exc()
|
||||
to_backlog.append(item)
|
||||
|
||||
if not to_score:
|
||||
return None, to_backlog
|
||||
|
||||
scored_data = await self.score(to_score)
|
||||
return scored_data, to_backlog
|
||||
|
||||
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
|
||||
"""Score the generated transactions by executing them on Anvil"""
|
||||
if not rollout_group_data:
|
||||
return None
|
||||
|
||||
scores = ScoredDataGroup()
|
||||
scores["tokens"] = []
|
||||
scores["masks"] = []
|
||||
scores["scores"] = []
|
||||
scores["advantages"] = None
|
||||
scores["ref_logprobs"] = None
|
||||
scores["messages"] = None
|
||||
scores["group_overrides"] = {"group_size": self.config.group_size}
|
||||
scores["overrides"] = None
|
||||
scores["ground_truths"] = []
|
||||
|
||||
for i, item in enumerate(rollout_group_data):
|
||||
out = tokenize_for_trainer(self.tokenizer, item[0])
|
||||
tokens = out["tokens"]
|
||||
masks = out["masks"]
|
||||
|
||||
try:
|
||||
# Extract the agent's response (transaction JSON)
|
||||
agent_response = item[0][-1]["content"].strip()
|
||||
ground_truth = item[1] if isinstance(item[1], str) else ""
|
||||
|
||||
# Score the transaction
|
||||
score = await self._score_transaction(agent_response)
|
||||
|
||||
# Display Score
|
||||
self.logger.debug(f"\n=== Score {i+1} ===")
|
||||
self.logger.debug(f"{score}")
|
||||
self.logger.debug("=" * 50)
|
||||
|
||||
# Track performance for this question type
|
||||
if self.current_question_type:
|
||||
self.question_performance[self.current_question_type].append(score)
|
||||
# Keep only last 10 scores per question type
|
||||
if len(self.question_performance[self.current_question_type]) > 10:
|
||||
self.question_performance[self.current_question_type].pop(0)
|
||||
|
||||
except Exception as e:
|
||||
score = -1.0
|
||||
ground_truth = item[1] if isinstance(item[1], str) else ""
|
||||
|
||||
# Display Score for error case
|
||||
print(f"\n=== Score {i+1} ===")
|
||||
print(f"{score} (Error: {e})")
|
||||
print("=" * 50)
|
||||
|
||||
# Skip if too few tokens
|
||||
if len([i for i in masks if i != -100]) < 10:
|
||||
continue
|
||||
|
||||
scores["tokens"].append(tokens)
|
||||
scores["masks"].append(masks)
|
||||
scores["scores"].append(score)
|
||||
scores["ground_truths"].append(ground_truth)
|
||||
|
||||
if len(scores["tokens"]) >= self.config.group_size:
|
||||
break
|
||||
|
||||
if not scores["tokens"]:
|
||||
return None
|
||||
|
||||
return scores
|
||||
|
||||
async def _score_transaction(self, agent_response: str) -> float:
|
||||
"""Score a transaction based on multiple criteria"""
|
||||
try:
|
||||
# First, extract JSON from the response (handling thinking tags)
|
||||
tx_obj = self._extract_transaction_json(agent_response)
|
||||
if tx_obj is None:
|
||||
return -1.0 # Could not extract valid JSON
|
||||
|
||||
# Validate required fields
|
||||
if not all(field in tx_obj for field in ["to", "value", "data"]):
|
||||
return -1.0 # Missing required fields
|
||||
|
||||
# Get expected transfer details from stored prompt data
|
||||
if not hasattr(self, "current_prompt_data") or not self.current_prompt_data:
|
||||
return -1.0
|
||||
|
||||
expected_token = self.current_prompt_data.get("transfer_token", "ETH")
|
||||
expected_amount = self.current_prompt_data.get("transfer_amount", "0")
|
||||
expected_destination = self.current_prompt_data.get(
|
||||
"destination_address", ""
|
||||
)
|
||||
|
||||
# Get sender and destination addresses
|
||||
sender_address = self.anvil_config.funding.custom_wallet
|
||||
destination_address = tx_obj.get("to", "")
|
||||
|
||||
# Get relevant tokens to check
|
||||
relevant_tokens = ["ETH"]
|
||||
if expected_token != "ETH":
|
||||
relevant_tokens.append(expected_token)
|
||||
|
||||
# Take a snapshot before execution
|
||||
snapshot_id = self.blockchain.snapshot()
|
||||
|
||||
# Get pre-execution balances for both addresses
|
||||
pre_balances = {
|
||||
"sender": self.blockchain.get_wallet_balances(
|
||||
sender_address, relevant_tokens
|
||||
),
|
||||
"destination": self.blockchain.get_wallet_balances(
|
||||
destination_address, relevant_tokens
|
||||
),
|
||||
}
|
||||
|
||||
try:
|
||||
# Execute the transaction
|
||||
result = self.blockchain.execute_transaction(tx_obj)
|
||||
|
||||
# Get post-execution balances
|
||||
post_balances = {
|
||||
"sender": self.blockchain.get_wallet_balances(
|
||||
sender_address, relevant_tokens
|
||||
),
|
||||
"destination": self.blockchain.get_wallet_balances(
|
||||
destination_address, relevant_tokens
|
||||
),
|
||||
}
|
||||
|
||||
# Calculate score based on execution result and balance changes
|
||||
score = self._calculate_transaction_score(
|
||||
tx_obj,
|
||||
result,
|
||||
agent_response,
|
||||
pre_balances,
|
||||
post_balances,
|
||||
expected_token,
|
||||
expected_amount,
|
||||
expected_destination,
|
||||
)
|
||||
|
||||
# Revert to snapshot to maintain clean state
|
||||
self.blockchain.revert(snapshot_id)
|
||||
|
||||
return score
|
||||
|
||||
except Exception:
|
||||
# Revert on any error
|
||||
self.blockchain.revert(snapshot_id)
|
||||
return -1.0 # Execution error
|
||||
|
||||
except Exception:
|
||||
return -1.0 # General error
|
||||
|
||||
def _calculate_transaction_score(
|
||||
self,
|
||||
tx_obj: Dict[str, Any],
|
||||
result: Dict[str, Any],
|
||||
agent_response: str,
|
||||
pre_balances: Dict[str, Dict[str, Any]],
|
||||
post_balances: Dict[str, Dict[str, Any]],
|
||||
transfer_token: str,
|
||||
transfer_amount: str,
|
||||
destination_address: str,
|
||||
) -> float:
|
||||
"""Calculate score based on transaction execution results and balance changes"""
|
||||
base_score = 0.0
|
||||
|
||||
# 1. Successful execution (0.3 points)
|
||||
if result.get("status") == "0x1":
|
||||
base_score += 0.3 # Transaction succeeded
|
||||
|
||||
# 2. Correct transaction - exact balance verification (0.5 points)
|
||||
balance_score = self._verify_expected_transfers(
|
||||
pre_balances,
|
||||
post_balances,
|
||||
transfer_token,
|
||||
transfer_amount,
|
||||
destination_address,
|
||||
)
|
||||
base_score += balance_score
|
||||
|
||||
# 3. Thinking quality (max 0.1 points, with negative for missing thinking)
|
||||
thinking_score = self._analyze_thinking_quality(agent_response)
|
||||
base_score += thinking_score # Range: -0.2 to +0.1
|
||||
|
||||
# 4. To field verification (0.05 points)
|
||||
tx_to = tx_obj.get("to", "").lower()
|
||||
expected_to = destination_address.lower()
|
||||
if tx_to == expected_to:
|
||||
base_score += 0.05
|
||||
|
||||
# 5. Data field verification (0.05 points)
|
||||
data_field = tx_obj.get("data", "0x")
|
||||
if transfer_token == "ETH":
|
||||
# ETH transfer should have empty data field
|
||||
if data_field == "0x":
|
||||
base_score += 0.05
|
||||
else:
|
||||
# ERC-20 transfer should have transfer function call data
|
||||
if data_field.startswith("0xa9059cbb"): # transfer function selector
|
||||
base_score += 0.05
|
||||
|
||||
return base_score
|
||||
|
||||
def _verify_expected_transfers(
|
||||
self,
|
||||
pre_balances: Dict[str, Dict[str, Any]],
|
||||
post_balances: Dict[str, Dict[str, Any]],
|
||||
expected_token: str,
|
||||
expected_amount: str,
|
||||
expected_destination: str,
|
||||
) -> float:
|
||||
"""Verify that the expected transfer amounts occurred"""
|
||||
try:
|
||||
expected_amount_float = float(expected_amount)
|
||||
|
||||
# For ETH transfers - only check destination address
|
||||
if expected_token == "ETH":
|
||||
# Extract balance values from the nested dictionary structure
|
||||
dest_pre_balance = (
|
||||
pre_balances["destination"].get("ETH", {}).get("balance", 0)
|
||||
)
|
||||
dest_post_balance = (
|
||||
post_balances["destination"].get("ETH", {}).get("balance", 0)
|
||||
)
|
||||
|
||||
dest_eth_change = float(dest_post_balance) - float(dest_pre_balance)
|
||||
|
||||
# Check if destination gained exactly the expected amount
|
||||
if abs(dest_eth_change - expected_amount_float) < 1e-10:
|
||||
return 0.5
|
||||
|
||||
# For ERC-20 transfers - check both sender and destination
|
||||
else:
|
||||
# Extract balance values from the nested dictionary structure
|
||||
sender_pre_balance = (
|
||||
pre_balances["sender"].get(expected_token, {}).get("balance", 0)
|
||||
)
|
||||
sender_post_balance = (
|
||||
post_balances["sender"].get(expected_token, {}).get("balance", 0)
|
||||
)
|
||||
dest_pre_balance = (
|
||||
pre_balances["destination"]
|
||||
.get(expected_token, {})
|
||||
.get("balance", 0)
|
||||
)
|
||||
dest_post_balance = (
|
||||
post_balances["destination"]
|
||||
.get(expected_token, {})
|
||||
.get("balance", 0)
|
||||
)
|
||||
|
||||
sender_token_change = float(sender_post_balance) - float(
|
||||
sender_pre_balance
|
||||
)
|
||||
dest_token_change = float(dest_post_balance) - float(dest_pre_balance)
|
||||
|
||||
# For ERC-20, expect exact amounts (no gas costs in token)
|
||||
if (
|
||||
abs(sender_token_change + expected_amount_float) < 1e-6
|
||||
and abs(dest_token_change - expected_amount_float) < 1e-6
|
||||
):
|
||||
return 0.5
|
||||
|
||||
return 0.0 # Transfer amounts don't match expectations
|
||||
|
||||
except (ValueError, TypeError):
|
||||
return 0.0
|
||||
|
||||
def _analyze_thinking_quality(self, response: str) -> float:
|
||||
"""Evaluate thinking tag quality with max 0.1 points, negative for missing thinking"""
|
||||
thinking_score = 0.0
|
||||
|
||||
# Check for thinking tags
|
||||
if "<think>" not in response or "</think>" not in response:
|
||||
return -0.2 # Penalty for no thinking tags
|
||||
|
||||
# Extract thinking content
|
||||
try:
|
||||
thinking_match = re.search(r"<think>(.*?)</think>", response, re.DOTALL)
|
||||
if not thinking_match:
|
||||
return -0.2 # No thinking content found
|
||||
|
||||
thinking_content = thinking_match.group(1).strip()
|
||||
if not thinking_content:
|
||||
return -0.1 # Empty thinking tags
|
||||
|
||||
# Basic quality assessment for positive score (max 0.1)
|
||||
word_count = len(thinking_content.split())
|
||||
|
||||
# Award points based on thinking depth
|
||||
if word_count >= 50: # Substantial thinking
|
||||
thinking_score += 0.1
|
||||
elif word_count >= 20: # Moderate thinking
|
||||
thinking_score += 0.05
|
||||
elif word_count >= 5: # Minimal thinking
|
||||
thinking_score += 0.02
|
||||
|
||||
return thinking_score
|
||||
|
||||
except Exception:
|
||||
return -0.1 # Error in processing thinking
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""Evaluation method - could implement portfolio performance tracking"""
|
||||
return
|
||||
|
||||
def close(self):
|
||||
"""Clean up resources"""
|
||||
cleanup_blockchain(self.blockchain)
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[EVMEnvConfig, List[APIServerConfig]]:
|
||||
"""Initialize configuration for EVM environment"""
|
||||
# pydantic-settings automatically loads from YAML file
|
||||
env_config = EVMEnvConfig(
|
||||
tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
|
||||
group_size=4,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=500,
|
||||
batch_size=16,
|
||||
steps_per_eval=50,
|
||||
max_token_length=2048,
|
||||
wandb_name="evm-agent",
|
||||
anvil_config_path="configs/token_transfers.yaml",
|
||||
)
|
||||
|
||||
# API server configuration
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="gpt-4o-mini",
|
||||
base_url=None, # Use OpenAI directly
|
||||
api_key=os.environ.get("OPENAI_API_KEY"),
|
||||
num_requests_for_eval=64,
|
||||
),
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
EVMEnv.cli()
|
||||
Loading…
Add table
Add a link
Reference in a new issue