atropos/environments/community/cybersecurity_sigma/jaccard_reward_env.py
shannonsands b774e97215
Integrate subrahmanyam cybersecurity (#142)
* cybersecurity env for offline RL trajectories

* output file addition

* jsonl outputs

* code cleanup

* pulled out outputs and fixing .gitignore

* removed zip file

* gitignore typo fix

* Integrate cybersecurity Sigma rule generation environment

---------

Co-authored-by: Subrahmanyam Arunachalam <subrahmanyam.arunachalam@FVFGK0VTQ05P.local>
2025-05-28 08:41:51 +10:00

340 lines
12 KiB
Python

import os
import random
import re
from typing import Dict, List, Optional, Tuple, TypedDict, Union
import yaml
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import jaccard_score
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
APIServerConfig,
BaseEnv,
BaseEnvConfig,
ScoredDataGroup,
)
from atroposlib.type_definitions import Item, number
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
# Dummy API key so the OpenAI client library initializes; actual requests go
# to the locally hosted inference server configured in config_init().
os.environ["OPENAI_API_KEY"] = "x"
# System prompt sent with every request: forces <think>...</think> reasoning
# followed by a \boxed{...} block containing pure YAML, so score() can extract
# the rule with a single regex.  NOTE: this is runtime prompt text — any edit
# changes model behavior.
system_prompt = """
You are a systematic, deep-reasoning AI trained to solve detection engineering problems by
constructing Sigma rules. You must first **think carefully** through the problem using structured
internal monologue enclosed in <think>...</think> tags, and then provide your final Sigma rule as
a YAML block enclosed in a LaTeX \\boxed{...} environment.
**You MUST follow this exact output format:**
<think>
Step-by-step reasoning, outlining how you arrived at the rule. Include relevant knowledge about
Sigma syntax, threat detection context, and how you chose each field.
</think>
\\boxed{
<YAML Sigma Rule>
}
**Important Rules:**
- DO NOT skip the <think> tags — all your thoughts must be inside them.
- DO NOT skip the \\boxed{...} wrapper — your final YAML answer MUST go inside it.
- DO NOT output anything outside the <think> and \\boxed{} blocks.
- The content inside \\boxed{} must be **pure YAML**, with **no extra formatting characters**
(no bullets, backticks, emojis, or markdown) so it can be passed **directly** to `yaml.safe_load()`.
- You are allocated a maximum of 2048 tokens — be detailed but concise.
- Your final output must be valid YAML for a Sigma rule, with correct indentation and field names.
- Use Sigma best practices: include `detection.condition`, `detection.selection`,
`logsource`, and `timeframe` when appropriate.
- Match the format and style of this example:
Example:
<think>
This rule is intended to detect potential lateral movement attempts by monitoring for excessive
outbound connections. The condition 'selection | count(dst_ip) by src_ip > 10' flags when one
source IP connects to more than 10 destination IPs. The log source is firewall logs, and the
action is 'denied'.
</think>
\\boxed{
detection:
condition: selection | count(dst_ip) by src_ip > 10
selection:
action: denied
timeframe: 24h
logsource:
category: firewall
}
Only produce answers in this format. Think first, then answer clearly in YAML. Follow YAML syntax
exactly.
"""
class SigmaRuleEntry(TypedDict):
    """One training example from the Sigma-rule dataset.

    Keys mirror the dataset columns actually read by
    ``collect_trajectories`` (``item["prompt"]`` / ``item["rule"]``); the
    previous ``question``/``answer`` field names matched no access site in
    this file.  TypedDict fields are static annotations only, so this rename
    has no runtime effect.
    """

    # Natural-language description of the detection task shown to the model.
    prompt: str
    # Gold Sigma rule (YAML text) used to compute the Jaccard reward.
    rule: str
class SigmaRuleEnv(BaseEnv):
    """RL environment that trains a model to author Sigma detection rules,
    scoring each generation by Jaccard similarity against a gold rule."""

    name = "sigmarule"

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        """Set up reward-tracking buffers on top of the base environment."""
        super().__init__(config, server_configs, slurm, testing)
        # Rolling buffer of per-rollout rewards; drained on each wandb_log().
        self.percent_correct_buffer = []
        # (metric_name, value) pairs queued by evaluate() for the next log.
        self.eval_metrics = []
        # Extra state kept for wandb visualizations / length statistics.
        self.rollouts_for_wandb = []
        self.completion_lengths = []
@classmethod
def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
env_config = BaseEnvConfig(
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
group_size=8,
use_wandb=True,
rollout_server_url="http://localhost:8000",
total_steps=1000,
batch_size=12,
steps_per_eval=100,
max_token_length=2048,
wandb_name="gsm8k",
)
server_configs = [
APIServerConfig(
model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
base_url="http://localhost:9001/v1",
# api_key="x",
num_requests_for_eval=256,
),
]
return env_config, server_configs
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
if wandb_metrics is None:
wandb_metrics = {}
# Try to calculate percent_correct, pass if there's a division by zero
try:
wandb_metrics["train/percent_correct"] = sum(
self.percent_correct_buffer
) / len(self.percent_correct_buffer)
except ZeroDivisionError:
# Skip if buffer is empty
pass
self.percent_correct_buffer = list()
for item in self.eval_metrics:
wandb_metrics[item[0]] = item[1]
self.eval_metrics = list()
# Call the parent method to handle the server metrics
await super().wandb_log(wandb_metrics)
async def setup(self):
self.train = load_dataset(
"mmaisel1/nous-rl-hackathon-sigma", split="train"
).shuffle(seed=42)
test_data = load_dataset(
"mmaisel1/nous-rl-hackathon-sigma", split="test"
).shuffle(seed=42)
self.test = list()
for item in test_data:
self.test.append({"question": item["prompt"], "gold_answer": item["rule"]})
self.iter = 0
def save_checkpoint(self, step, data=None):
if data is None:
data = {}
data["iter"] = self.iter
super().save_checkpoint(step, data)
def flatten_detection(self, det_dict):
if not isinstance(det_dict, dict):
return ""
return " ".join(f"{k}:{v}" for k, v in det_dict.items())
def jaccard_similarity(self, str1, str2):
vec = CountVectorizer(binary=True).fit([str1, str2])
v1 = vec.transform([str1]).toarray()[0]
v2 = vec.transform([str2]).toarray()[0]
print("vectors", v1, v2)
return jaccard_score(v1, v2) if (v1.any() and v2.any()) else 0.0
    async def rollout_and_score_eval(self, question: str, answer: str) -> number:
        """Run one greedy (temperature 0) eval completion for *question* and
        score its detection block against the gold *answer* by Jaccard
        similarity.

        NOTE(review): math_verify's ``parse()`` returns a list of extracted
        expressions, not a mapping, so the ``gold_parsed["detection"]`` and
        ``answer_parsed["detection"]`` subscripts below look like they would
        raise TypeError at runtime — confirm against math_verify's API.
        score() instead extracts the boxed YAML with a regex; consider the
        same approach here.
        """
        completion = await self.server.chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": question},
            ],
            n=1,
            max_tokens=self.config.max_token_length,
            temperature=0.0,  # greedy decoding for deterministic eval
            split="eval",
        )
        # Wrap the gold rule in \boxed{} so the LaTeX extractor can find it.
        gold_parsed = parse(
            "\\boxed{" + answer + "}",
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig()],
        )
        # Only look at text after the reasoning block.
        answer_parsed = parse(
            completion.choices[0].message.content.split("</think>")[-1],
            extraction_config=[
                LatexExtractionConfig(
                    normalization_config=NormalizationConfig(
                        nits=False,
                        malformed_operators=False,
                        basic_latex=True,
                        equations=True,
                        boxed="all",
                        units=True,
                    ),
                    # Ensures that boxed is tried first
                    boxed_match_priority=0,
                    try_extract_without_anchor=False,
                )
            ],
            extraction_mode="first_match",
        )
        # Compare only the flattened `detection` sections of the two rules.
        score = self.jaccard_similarity(
            self.flatten_detection(gold_parsed["detection"]),
            self.flatten_detection(answer_parsed["detection"]),
        )
        return score
async def evaluate(self, *args, **kwargs):
eval_tasks = []
for item in self.test:
eval_tasks.append(
self.rollout_and_score_eval(item["question"], item["gold_answer"])
)
scores = await tqdm_asyncio.gather(*eval_tasks)
self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))
async def collect_trajectories(
self, item: SigmaRuleEntry
) -> Tuple[ScoredDataGroup, list[Item]]:
user_message = {"role": "user", "content": item["prompt"]}
gold_answer = (
"\\boxed{" + item["rule"].split("#")[-1].strip().replace(",", "") + "}"
)
chat_completions = await self.server.chat_completion(
messages=[{"role": "system", "content": system_prompt}, user_message],
n=self.config.group_size,
max_tokens=self.config.max_token_length,
)
to_score = list()
to_backlog = list()
for i, chat_completion in enumerate(chat_completions.choices):
messages = (
{"role": "system", "content": system_prompt},
user_message,
{"role": "assistant", "content": chat_completion.message.content},
)
to_score.append(
{
"messages": messages,
"gold_answer": gold_answer,
"finish_reason": chat_completion.finish_reason,
}
)
to_postprocess = await self.score(to_score)
return to_postprocess, to_backlog
async def score(
self, rollout_group_data
) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
scores = ScoredDataGroup()
scores["tokens"] = list()
scores["masks"] = list()
scores["scores"] = list()
try:
# gold_yaml = yaml.safe_load(rollout_group_data[0]["gold_answer"])
gold_det = rollout_group_data[0]["gold_answer"]
except yaml.YAMLError:
reward = 0
random.shuffle(rollout_group_data)
for item in rollout_group_data:
out_dict = tokenize_for_trainer(
self.tokenizer, item["messages"], item["finish_reason"]
)
tokens = out_dict["tokens"]
masks = out_dict["masks"]
boxed_match = re.search(
r"\\boxed\{(.*)\}\s*$", item["messages"][-1]["content"], re.DOTALL
)
try:
if boxed_match:
yaml_str = boxed_match.group(1)
# gen_yaml = yaml.safe_load(yaml_str)
# gen_det = self.flatten_detection(gen_yaml.get("detection", {}))
reward = self.jaccard_similarity(gold_det, yaml_str)
else:
reward = 0
except Exception as e:
print(e)
reward = 0
print("GOLD ANSWER:", rollout_group_data[0]["gold_answer"])
print("GEN OUTPUT:", item["messages"][-1]["content"])
print("REWARD:", reward)
# Optional: Add LLM-based score for semantic evaluation (not shown here)
# if len([1 for i in masks if i != -100]) < 10:
# continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
scores["scores"].append(reward)
# if len(scores["tokens"]) >= self.config.group_size:
# break
# Buffer for analysis
for score in scores["scores"]:
self.percent_correct_buffer.append(score)
# Optional: Apply length penalty if all rewards are 1.0
if all(score >= 0.99 for score in scores["scores"]): # float-safe check
token_lengths = [len(token) for token in scores["tokens"]]
if max(token_lengths) == 0:
return None
max_allowed_length = self.config.max_token_length
length_threshold = max_allowed_length * 0.5
scores["scores"] = []
for length in token_lengths:
if length <= length_threshold:
scores["scores"].append(1.0)
else:
pct_range = (length - length_threshold) / (
max_allowed_length - length_threshold
)
scores["scores"].append(1.0 - min(pct_range, 1.0))
# if all([scores["scores"][0] == score for score in scores["scores"]]):
# return None
return scores if scores["tokens"] else None
async def get_next_item(self) -> SigmaRuleEntry:
next_item = self.train[self.iter % len(self.train)]
self.iter += 1
return next_item
# Script entry point: cli() is presumably inherited from BaseEnv (serve /
# process subcommands) — confirm against atroposlib.envs.base.
if __name__ == "__main__":
    SigmaRuleEnv.cli()