fix: align ScamOrRugEnv with BaseEnv API and add wandb logging

This commit is contained in:
kokoron 2026-03-31 08:29:02 +00:00
parent 9a0554ddc9
commit 551cc7187d

View file

@@ -1,7 +1,10 @@
import random
import re
from dataclasses import dataclass
from typing import Optional
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, ScoredDataGroup
from atroposlib.envs.server_handling.server_baseline import ServerBaseline
from atroposlib.type_definitions import Item
VALID_BURN_ADDRESSES = {
@@ -14,9 +17,9 @@ VALID_BURN_ADDRESSES = {
class ScamOrRugConfig(BaseEnvConfig):
tokenizer_name: str = "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
group_size: int = 8
max_token_len: int = 1024
num_rollouts: int = 256
num_iterations: int = 1024
max_token_length: int = 1024
num_rollouts_to_keep: int = 32
total_steps: int = 1000
SYSTEM_PROMPT = """You are an on-chain analyst from the perspective of an average Web3 user trying to protect themselves from scams and rug pulls.
@@ -57,7 +60,6 @@ def generate_fake_burn_address() -> str:
def generate_token_data(label: str) -> dict:
if label == "SCAM":
supply = random.randint(1_000_000, 1_000_000_000_000)
cluster_pct = round(random.uniform(55, 92), 2)
@@ -108,7 +110,7 @@ def generate_token_data(label: str) -> dict:
"can_sell": True,
"burn_address": random.choice([
generate_fake_burn_address(),
"0x0000000000000000000000000000000000000000"
"0x0000000000000000000000000000000000000000",
]),
"has_recover_function": random.choice([True, False]),
"has_mint_function": random.choice([True, False]),
@@ -204,12 +206,12 @@ def score_response(response: str, data: dict, true_label: str) -> float:
if classification == true_label:
score += 0.4
elif (
(true_label == "SCAM" and classification == "RUG_RISK") or
(true_label == "RUG_RISK" and classification == "SCAM")
(true_label == "SCAM" and classification == "RUG_RISK")
or (true_label == "RUG_RISK" and classification == "SCAM")
):
score += 0.1
# 2. Reasoning quality (0.3) — checks across all 7 dimensions
# 2. Reasoning quality (0.3)
keywords = {
"SCAM": ["cluster", "mint", "tax", "sell", "honeypot", "burn", "wash", "fake", "recover"],
"RUG_RISK": ["cluster", "lp", "lock", "tax", "risk", "upgrade", "dev"],
@@ -243,17 +245,26 @@ def score_response(response: str, data: dict, true_label: str) -> float:
class ScamOrRugEnv(BaseEnv):
name = "scam_or_rug_onchain"
def __init__(self, config: ScamOrRugConfig, **kwargs):
super().__init__(config, **kwargs)
self.labels = ["SCAM", "RUG_RISK", "LEGITIMATE"]
self.percent_correct_buffer = []
@classmethod
def config_init(cls) -> ScamOrRugConfig:
return ScamOrRugConfig()
def config_init(cls):
return ScamOrRugConfig(), ServerBaseline()
async def setup(self):
pass
async def get_next_item(self) -> Item:
label = random.choice(self.labels)
return (label,)
async def collect_trajectories(self, item) -> tuple:
label = random.choice(self.labels)
label = item[0]
data = generate_token_data(label)
prompt = format_prompt(data)
@@ -265,7 +276,7 @@ class ScamOrRugEnv(BaseEnv):
completions = await self.server.completion(
messages=messages,
n=self.config.group_size,
max_tokens=self.config.max_token_len,
max_tokens=self.config.max_token_length,
)
scored = ScoredDataGroup()
@@ -276,16 +287,45 @@ class ScamOrRugEnv(BaseEnv):
for completion in completions.choices:
response = completion.message.content
reward = score_response(response, data, label)
tokens, masks = self.tokenize_for_training(messages, response)
# tokenize
full_text = self.tokenizer.apply_chat_template(
messages + [{"role": "assistant", "content": response}],
tokenize=True,
add_generation_prompt=False,
)
prompt_text = self.tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
)
tokens = full_text
masks = [-100] * len(prompt_text) + full_text[len(prompt_text):]
scored["tokens"].append(tokens)
scored["masks"].append(masks)
scored["scores"].append(reward)
return scored, {}
# track accuracy
response_upper = response.upper()
for lbl in ["SCAM", "RUG_RISK", "LEGITIMATE"]:
if f"CLASSIFICATION: {lbl}" in response_upper:
self.percent_correct_buffer.append(1.0 if lbl == label else 0.0)
break
async def get_next_item(self) -> Item:
label = random.choice(self.labels)
return (label,)
return scored, []
async def wandb_log(self, wandb_metrics: Optional[dict] = None):
if wandb_metrics is None:
wandb_metrics = {}
if self.percent_correct_buffer:
wandb_metrics["train/percent_correct"] = sum(
self.percent_correct_buffer
) / len(self.percent_correct_buffer)
self.percent_correct_buffer = []
await super().wandb_log(wandb_metrics)
async def evaluate(self, *args, **kwargs):
pass