mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
fix: align ScamOrRugEnv with BaseEnv API and add wandb logging
This commit is contained in:
parent
9a0554ddc9
commit
551cc7187d
1 changed files with 57 additions and 17 deletions
|
|
@ -1,7 +1,10 @@
|
|||
import random
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_baseline import ServerBaseline
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
VALID_BURN_ADDRESSES = {
|
||||
|
|
@ -14,9 +17,9 @@ VALID_BURN_ADDRESSES = {
|
|||
class ScamOrRugConfig(BaseEnvConfig):
|
||||
tokenizer_name: str = "NousResearch/DeepHermes-3-Llama-3-8B-Preview"
|
||||
group_size: int = 8
|
||||
max_token_len: int = 1024
|
||||
num_rollouts: int = 256
|
||||
num_iterations: int = 1024
|
||||
max_token_length: int = 1024
|
||||
num_rollouts_to_keep: int = 32
|
||||
total_steps: int = 1000
|
||||
|
||||
|
||||
SYSTEM_PROMPT = """You are an on-chain analyst from the perspective of an average Web3 user trying to protect themselves from scams and rug pulls.
|
||||
|
|
@ -57,7 +60,6 @@ def generate_fake_burn_address() -> str:
|
|||
|
||||
|
||||
def generate_token_data(label: str) -> dict:
|
||||
|
||||
if label == "SCAM":
|
||||
supply = random.randint(1_000_000, 1_000_000_000_000)
|
||||
cluster_pct = round(random.uniform(55, 92), 2)
|
||||
|
|
@ -108,7 +110,7 @@ def generate_token_data(label: str) -> dict:
|
|||
"can_sell": True,
|
||||
"burn_address": random.choice([
|
||||
generate_fake_burn_address(),
|
||||
"0x0000000000000000000000000000000000000000"
|
||||
"0x0000000000000000000000000000000000000000",
|
||||
]),
|
||||
"has_recover_function": random.choice([True, False]),
|
||||
"has_mint_function": random.choice([True, False]),
|
||||
|
|
@ -204,12 +206,12 @@ def score_response(response: str, data: dict, true_label: str) -> float:
|
|||
if classification == true_label:
|
||||
score += 0.4
|
||||
elif (
|
||||
(true_label == "SCAM" and classification == "RUG_RISK") or
|
||||
(true_label == "RUG_RISK" and classification == "SCAM")
|
||||
(true_label == "SCAM" and classification == "RUG_RISK")
|
||||
or (true_label == "RUG_RISK" and classification == "SCAM")
|
||||
):
|
||||
score += 0.1
|
||||
|
||||
# 2. Reasoning quality (0.3) — checks across all 7 dimensions
|
||||
# 2. Reasoning quality (0.3)
|
||||
keywords = {
|
||||
"SCAM": ["cluster", "mint", "tax", "sell", "honeypot", "burn", "wash", "fake", "recover"],
|
||||
"RUG_RISK": ["cluster", "lp", "lock", "tax", "risk", "upgrade", "dev"],
|
||||
|
|
@ -243,17 +245,26 @@ def score_response(response: str, data: dict, true_label: str) -> float:
|
|||
|
||||
|
||||
class ScamOrRugEnv(BaseEnv):
|
||||
name = "scam_or_rug_onchain"
|
||||
|
||||
def __init__(self, config: ScamOrRugConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
self.labels = ["SCAM", "RUG_RISK", "LEGITIMATE"]
|
||||
self.percent_correct_buffer = []
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> ScamOrRugConfig:
|
||||
return ScamOrRugConfig()
|
||||
def config_init(cls):
|
||||
return ScamOrRugConfig(), ServerBaseline()
|
||||
|
||||
async def setup(self):
|
||||
pass
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
label = random.choice(self.labels)
|
||||
return (label,)
|
||||
|
||||
async def collect_trajectories(self, item) -> tuple:
|
||||
label = random.choice(self.labels)
|
||||
label = item[0]
|
||||
data = generate_token_data(label)
|
||||
prompt = format_prompt(data)
|
||||
|
||||
|
|
@ -265,7 +276,7 @@ class ScamOrRugEnv(BaseEnv):
|
|||
completions = await self.server.completion(
|
||||
messages=messages,
|
||||
n=self.config.group_size,
|
||||
max_tokens=self.config.max_token_len,
|
||||
max_tokens=self.config.max_token_length,
|
||||
)
|
||||
|
||||
scored = ScoredDataGroup()
|
||||
|
|
@ -276,16 +287,45 @@ class ScamOrRugEnv(BaseEnv):
|
|||
for completion in completions.choices:
|
||||
response = completion.message.content
|
||||
reward = score_response(response, data, label)
|
||||
tokens, masks = self.tokenize_for_training(messages, response)
|
||||
|
||||
# tokenize
|
||||
full_text = self.tokenizer.apply_chat_template(
|
||||
messages + [{"role": "assistant", "content": response}],
|
||||
tokenize=True,
|
||||
add_generation_prompt=False,
|
||||
)
|
||||
prompt_text = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
tokens = full_text
|
||||
masks = [-100] * len(prompt_text) + full_text[len(prompt_text):]
|
||||
|
||||
scored["tokens"].append(tokens)
|
||||
scored["masks"].append(masks)
|
||||
scored["scores"].append(reward)
|
||||
|
||||
return scored, {}
|
||||
# track accuracy
|
||||
response_upper = response.upper()
|
||||
for lbl in ["SCAM", "RUG_RISK", "LEGITIMATE"]:
|
||||
if f"CLASSIFICATION: {lbl}" in response_upper:
|
||||
self.percent_correct_buffer.append(1.0 if lbl == label else 0.0)
|
||||
break
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
label = random.choice(self.labels)
|
||||
return (label,)
|
||||
return scored, []
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[dict] = None):
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self.percent_correct_buffer:
|
||||
wandb_metrics["train/percent_correct"] = sum(
|
||||
self.percent_correct_buffer
|
||||
) / len(self.percent_correct_buffer)
|
||||
self.percent_correct_buffer = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
pass
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue