mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Integrate subrahmanyam cybersecurity (#142)
* cybersecurity env for offline RL trajectories
* output file addition
* jsonl outputs
* code cleanup
* pulled out outputs and fixing .gitignore
* removed zip file
* gitignore typo fix
* Integrate cybersecurity Sigma rule generation environment

---------

Co-authored-by: Subrahmanyam Arunachalam <subrahmanyam.arunachalam@FVFGK0VTQ05P.local>
This commit is contained in:
parent
b33070f56b
commit
b774e97215
4 changed files with 890 additions and 0 deletions
139  environments/community/cybersecurity_sigma/README.md  Normal file
@@ -0,0 +1,139 @@
# Cybersecurity Sigma Rule Generation Environment

This environment trains LLMs to generate semantically correct Sigma detection rules from threat-hunting prompts. It provides two different reward mechanisms for evaluating generated rules.

## Overview

The environment focuses on structured generation tasks where outputs must be valid YAML conforming to Sigma detection rule schemas. It includes two implementations with different reward functions:

1. **Jaccard Similarity Reward** (`jaccard_reward_env.py`) - Uses token-based similarity scoring
2. **LLM Judge Reward** (`llm_judge_env.py`) - Uses LLM-based semantic evaluation

## Core Features

### Dataset Integration
- Uses the `mmaisel1/nous-rl-hackathon-sigma` dataset from Hugging Face
- Contains threat-hunting prompts paired with corresponding Sigma rules
- Automatic train/test split with seeded shuffling for reproducibility

### Structured Output Format
- Enforces a specific output format with `<think>...</think>` reasoning tags
- Requires YAML output wrapped in a LaTeX `\boxed{...}` environment
- Validates YAML syntax and Sigma rule structure

### Dual Reward Mechanisms

#### Jaccard Similarity Scoring
- Compares flattened key paths of the gold and generated YAML under the `detection:` section
- Uses scikit-learn's Jaccard similarity for token-based matching
- Tends to produce low, sparse rewards due to structural mismatches
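The token-based scoring can be sketched in a few lines. This mirrors the scikit-learn calls the environment uses; the two flattened detection strings are hypothetical examples, not dataset entries:

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import jaccard_score


def jaccard(str1: str, str2: str) -> float:
    """Binary token overlap between two flattened detection strings."""
    vec = CountVectorizer(binary=True).fit([str1, str2])
    v1 = vec.transform([str1]).toarray()[0]
    v2 = vec.transform([str2]).toarray()[0]
    # Guard against all-zero vectors, which make jaccard_score undefined
    return jaccard_score(v1, v2) if (v1.any() and v2.any()) else 0.0


# Illustrative flattened "key:value" strings for two detection sections
gold = "condition:selection selection:{'Image|endswith': '.dll'}"
gen = "condition:selection selection:{'Image|endswith': '.exe'}"
score = jaccard(gold, gen)
```

Because only one token differs (`dll` vs. `exe`), the overlap is high but not 1.0, which illustrates why structurally near-identical rules still score below full reward.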
#### LLM-as-a-Judge Scoring
- Uses a binary LLM evaluation for semantic-equivalence assessment
- Returns 1.0 if the generated rule is functionally equivalent to the gold standard
- Provides higher-fidelity supervision even when structure varies

### Advanced Features
- Length penalty system for overly verbose outputs
- Comprehensive evaluation metrics tracking
- W&B integration for experiment monitoring
- Configurable token limits and batch sizes
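The length penalty is only applied when every completion in a group already scores near 1.0. Under that assumption, the schedule implemented in both environment files grants full reward up to half the token budget and then decays linearly to zero at the budget limit:

```python
def length_penalty(length: int, max_len: int = 2048) -> float:
    """Reward multiplier: 1.0 up to max_len/2, linear decay to 0.0 at max_len."""
    threshold = max_len * 0.5
    if length <= threshold:
        return 1.0
    # Fraction of the way from the threshold to the hard token limit
    pct_range = (length - threshold) / (max_len - threshold)
    return 1.0 - min(pct_range, 1.0)
```

With the default 2048-token budget, a 1024-token rule keeps its full reward, a 1536-token rule keeps half, and a rule at the limit earns nothing.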
## Technical Implementation

### Environment Configuration
- **Model**: NousResearch/DeepHermes-3-Llama-3-3B-Preview
- **Max Token Length**: 2048 tokens
- **Group Size**: 8 completions per prompt
- **Batch Size**: 12 items per batch
- **Evaluation Frequency**: Every 100 steps

### System Prompt Structure
The environment uses a detailed system prompt that:
- Enforces structured reasoning with `<think>` tags
- Requires YAML output in a `\boxed{}` environment
- Provides Sigma rule best practices and examples
- Specifies exact formatting requirements for parser compatibility

### Scoring Pipeline
1. **Extraction**: Parse YAML from the `\boxed{}` wrapper using a regex
2. **Validation**: Attempt YAML parsing and structure validation
3. **Evaluation**: Apply either Jaccard similarity or LLM judge scoring
4. **Aggregation**: Collect scores for batch-level reward computation
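Steps 1 and 2 can be sketched as follows. This is a minimal illustration of the extraction-and-validation logic, not the exact implementation; `extract_rule` and the sample completion are illustrative names:

```python
import re

import yaml


def extract_rule(completion: str):
    """Pull YAML out of the \\boxed{...} wrapper, then validate it parses."""
    match = re.search(r"\\boxed\{(.*)\}\s*$", completion, re.DOTALL)
    if not match:
        return None  # malformed output -> zero reward downstream
    try:
        rule = yaml.safe_load(match.group(1))
    except yaml.YAMLError:
        return None
    # A Sigma rule must be a YAML mapping, not a scalar or list
    return rule if isinstance(rule, dict) else None


sample = "<think>reasoning</think>\n\\boxed{\ndetection:\n  condition: selection\n}"
rule = extract_rule(sample)
```

A failed extraction or parse returns `None`, which the scoring step maps to a reward of 0.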
## Setup and Usage

### Environment Variables
```bash
export OPENAI_API_KEY="your-openai-api-key"  # For LLM judge (optional)
export NOUS_API_KEY="your-nous-api-key"      # For model inference
```

### Command Line Usage
```bash
# Jaccard similarity reward
python environments/community/cybersecurity_sigma/jaccard_reward_env.py

# LLM judge reward
python environments/community/cybersecurity_sigma/llm_judge_env.py
```

### Dependencies
- `datasets` - Hugging Face dataset loading
- `scikit-learn` - Jaccard similarity computation (`jaccard_reward_env` only)
- `latex2sympy2_extended` - LaTeX parsing utilities
- `math_verify` - YAML extraction from LaTeX boxes
- `openai` - LLM judge API calls (`llm_judge_env` only)
## Research Applications

### Cybersecurity Training
- Train models to understand threat-detection patterns
- Generate rules for various attack vectors and techniques
- Develop automated threat-hunting capabilities

### Structured Generation Research
- Study LLM performance on constrained output formats
- Compare token-based vs. semantic evaluation methods
- Investigate reasoning quality in cybersecurity domains

### Evaluation Methodology Development
- Benchmark different reward-function approaches
- Analyze the correlation between structural and semantic correctness
- Develop better automated evaluation metrics for domain-specific tasks

## Performance Characteristics

### Jaccard Similarity Results
- **Typical Rewards**: 0.1-0.3 range due to structural sensitivity
- **Strengths**: Fast computation, deterministic scoring
- **Limitations**: Sensitive to formatting differences, low reward density

### LLM Judge Results
- **Typical Rewards**: Binary 0.0/1.0 with higher success rates
- **Strengths**: Semantic understanding, format flexibility
- **Limitations**: API latency, potential inconsistency, cost considerations
## Example Outputs

### Input Prompt
```
DotNET Assembly DLL Loaded Via Office Application: Detects any assembly DLL being loaded by an Office Product
```

### Expected Sigma Rule Format
```yaml
detection:
  condition: selection
  selection:
    process_name:
      - excel.exe
      - word.exe
      - powerpnt.exe
    dll_loaded: "*.dll"
logsource:
  category: process
  product: windows
```

The environment provides a robust framework for training LLMs on cybersecurity detection-rule generation, with flexible evaluation mechanisms suited to different research objectives.
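As a quick sanity check, the rule above round-trips through `yaml.safe_load()` (the compatibility the system prompt demands), and its `detection:` section can be flattened into the "key:value" tokens the Jaccard scorer compares:

```python
import yaml

# The expected Sigma rule from the example above
rule_yaml = """
detection:
  condition: selection
  selection:
    process_name:
      - excel.exe
      - word.exe
      - powerpnt.exe
    dll_loaded: "*.dll"
logsource:
  category: process
  product: windows
"""

rule = yaml.safe_load(rule_yaml)
# Flatten the detection mapping into "key:value" tokens, as the Jaccard scorer does
flat = " ".join(f"{k}:{v}" for k, v in rule["detection"].items())
```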
340  environments/community/cybersecurity_sigma/jaccard_reward_env.py  Normal file
@@ -0,0 +1,340 @@
import os
import random
import re
from typing import Dict, List, Optional, Tuple, TypedDict, Union

import yaml
from datasets import load_dataset
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import jaccard_score
from tqdm.asyncio import tqdm_asyncio

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)
from atroposlib.type_definitions import Item, number
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

os.environ["OPENAI_API_KEY"] = "x"

system_prompt = """
You are a systematic, deep-reasoning AI trained to solve detection engineering problems by
constructing Sigma rules. You must first **think carefully** through the problem using structured
internal monologue enclosed in <think>...</think> tags, and then provide your final Sigma rule as
a YAML block enclosed in a LaTeX \\boxed{...} environment.

**You MUST follow this exact output format:**
<think>
Step-by-step reasoning, outlining how you arrived at the rule. Include relevant knowledge about
Sigma syntax, threat detection context, and how you chose each field.
</think>

\\boxed{
<YAML Sigma Rule>
}

**Important Rules:**
- DO NOT skip the <think> tags — all your thoughts must be inside them.
- DO NOT skip the \\boxed{...} wrapper — your final YAML answer MUST go inside it.
- DO NOT output anything outside the <think> and \\boxed{} blocks.
- The content inside \\boxed{} must be **pure YAML**, with **no extra formatting characters**
(no bullets, backticks, emojis, or markdown) so it can be passed **directly** to `yaml.safe_load()`.
- You are allocated a maximum of 2048 tokens — be detailed but concise.
- Your final output must be valid YAML for a Sigma rule, with correct indentation and field names.
- Use Sigma best practices: include `detection.condition`, `detection.selection`,
`logsource`, and `timeframe` when appropriate.
- Match the format and style of this example:

Example:
<think>
This rule is intended to detect potential lateral movement attempts by monitoring for excessive
outbound connections. The condition 'selection | count(dst_ip) by src_ip > 10' flags when one
source IP connects to more than 10 destination IPs. The log source is firewall logs, and the
action is 'denied'.
</think>

\\boxed{
detection:
  condition: selection | count(dst_ip) by src_ip > 10
  selection:
    action: denied
  timeframe: 24h
logsource:
  category: firewall
}

Only produce answers in this format. Think first, then answer clearly in YAML. Follow YAML syntax
exactly.
"""


class SigmaRuleEntry(TypedDict):
    prompt: str
    rule: str


class SigmaRuleEnv(BaseEnv):

    name = "sigmarule"

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.percent_correct_buffer = list()
        self.eval_metrics = list()
        # Tracking for wandb visualizations
        self.rollouts_for_wandb = []
        self.completion_lengths = []

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        env_config = BaseEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            max_token_length=2048,
            wandb_name="sigma_rule_jaccard",
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
                base_url="http://localhost:9001/v1",
                num_requests_for_eval=256,
            ),
        ]

        return env_config, server_configs

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        if wandb_metrics is None:
            wandb_metrics = {}

        # Calculate percent_correct; skip if the buffer is empty
        try:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)
        except ZeroDivisionError:
            pass

        self.percent_correct_buffer = list()
        for item in self.eval_metrics:
            wandb_metrics[item[0]] = item[1]
        self.eval_metrics = list()
        # Call the parent method to handle the server metrics
        await super().wandb_log(wandb_metrics)

    async def setup(self):
        self.train = load_dataset(
            "mmaisel1/nous-rl-hackathon-sigma", split="train"
        ).shuffle(seed=42)
        test_data = load_dataset(
            "mmaisel1/nous-rl-hackathon-sigma", split="test"
        ).shuffle(seed=42)
        self.test = list()
        for item in test_data:
            self.test.append({"question": item["prompt"], "gold_answer": item["rule"]})
        self.iter = 0

    def save_checkpoint(self, step, data=None):
        if data is None:
            data = {}
        data["iter"] = self.iter
        super().save_checkpoint(step, data)

    def flatten_detection(self, det_dict):
        """Flatten a detection mapping into space-separated "key:value" tokens."""
        if not isinstance(det_dict, dict):
            return ""
        return " ".join(f"{k}:{v}" for k, v in det_dict.items())

    def jaccard_similarity(self, str1, str2):
        vec = CountVectorizer(binary=True).fit([str1, str2])
        v1 = vec.transform([str1]).toarray()[0]
        v2 = vec.transform([str2]).toarray()[0]
        return jaccard_score(v1, v2) if (v1.any() and v2.any()) else 0.0

    async def rollout_and_score_eval(self, question: str, answer: str) -> number:
        completion = await self.server.chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": question},
            ],
            n=1,
            max_tokens=self.config.max_token_length,
            temperature=0.0,
            split="eval",
        )
        # Extract the boxed YAML from the completion, then compare the
        # flattened detection sections of the gold and generated rules.
        boxed_match = re.search(
            r"\\boxed\{(.*)\}\s*$",
            completion.choices[0].message.content.split("</think>")[-1],
            re.DOTALL,
        )
        if not boxed_match:
            return 0.0
        try:
            gold_yaml = yaml.safe_load(answer)
            gen_yaml = yaml.safe_load(boxed_match.group(1))
        except yaml.YAMLError:
            return 0.0
        if not isinstance(gold_yaml, dict) or not isinstance(gen_yaml, dict):
            return 0.0
        return self.jaccard_similarity(
            self.flatten_detection(gold_yaml.get("detection", {})),
            self.flatten_detection(gen_yaml.get("detection", {})),
        )

    async def evaluate(self, *args, **kwargs):
        eval_tasks = []
        for item in self.test:
            eval_tasks.append(
                self.rollout_and_score_eval(item["question"], item["gold_answer"])
            )
        scores = await tqdm_asyncio.gather(*eval_tasks)
        self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))

    async def collect_trajectories(
        self, item: SigmaRuleEntry
    ) -> Tuple[ScoredDataGroup, list[Item]]:
        user_message = {"role": "user", "content": item["prompt"]}
        gold_answer = (
            "\\boxed{" + item["rule"].split("#")[-1].strip().replace(",", "") + "}"
        )

        chat_completions = await self.server.chat_completion(
            messages=[{"role": "system", "content": system_prompt}, user_message],
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
        )
        to_score = list()
        to_backlog = list()
        for chat_completion in chat_completions.choices:
            messages = (
                {"role": "system", "content": system_prompt},
                user_message,
                {"role": "assistant", "content": chat_completion.message.content},
            )
            to_score.append(
                {
                    "messages": messages,
                    "gold_answer": gold_answer,
                    "finish_reason": chat_completion.finish_reason,
                }
            )
        to_postprocess = await self.score(to_score)
        return to_postprocess, to_backlog

    async def score(
        self, rollout_group_data
    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
        scores = ScoredDataGroup()
        scores["tokens"] = list()
        scores["masks"] = list()
        scores["scores"] = list()

        # The gold answer is kept as the raw boxed string; each rollout's boxed
        # YAML is extracted below and compared against it token-by-token.
        gold_det = rollout_group_data[0]["gold_answer"]

        random.shuffle(rollout_group_data)

        for item in rollout_group_data:
            out_dict = tokenize_for_trainer(
                self.tokenizer, item["messages"], item["finish_reason"]
            )
            tokens = out_dict["tokens"]
            masks = out_dict["masks"]
            boxed_match = re.search(
                r"\\boxed\{(.*)\}\s*$", item["messages"][-1]["content"], re.DOTALL
            )
            try:
                if boxed_match:
                    yaml_str = boxed_match.group(1)
                    reward = self.jaccard_similarity(gold_det, yaml_str)
                else:
                    reward = 0
            except Exception as e:
                print(e)
                reward = 0

            print("GOLD ANSWER:", rollout_group_data[0]["gold_answer"])
            print("GEN OUTPUT:", item["messages"][-1]["content"])
            print("REWARD:", reward)

            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["scores"].append(reward)

        # Buffer for analysis
        for score in scores["scores"]:
            self.percent_correct_buffer.append(score)

        # Apply a length penalty when every rollout already scores ~1.0
        if all(score >= 0.99 for score in scores["scores"]):  # float-safe check
            token_lengths = [len(token) for token in scores["tokens"]]
            if max(token_lengths) == 0:
                return None
            max_allowed_length = self.config.max_token_length
            length_threshold = max_allowed_length * 0.5
            scores["scores"] = []
            for length in token_lengths:
                if length <= length_threshold:
                    scores["scores"].append(1.0)
                else:
                    pct_range = (length - length_threshold) / (
                        max_allowed_length - length_threshold
                    )
                    scores["scores"].append(1.0 - min(pct_range, 1.0))

        return scores if scores["tokens"] else None

    async def get_next_item(self) -> SigmaRuleEntry:
        next_item = self.train[self.iter % len(self.train)]
        self.iter += 1
        return next_item


if __name__ == "__main__":
    SigmaRuleEnv.cli()
368  environments/community/cybersecurity_sigma/llm_judge_env.py  Normal file
@@ -0,0 +1,368 @@
import os
import random
import re
from typing import Dict, List, Optional, Tuple, TypedDict, Union

from datasets import load_dataset
from tqdm.asyncio import tqdm_asyncio

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)
from atroposlib.type_definitions import Item, number
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

os.environ["OPENAI_API_KEY"] = "x"

system_prompt = """
You are a systematic, deep-reasoning AI trained to solve detection engineering problems by
constructing Sigma rules. You must first **think carefully** through the problem using structured
internal monologue enclosed in <think>...</think> tags, and then provide your final Sigma rule as
a YAML block enclosed in a LaTeX \\boxed{...} environment.

**You MUST follow this exact output format:**
<think>
Step-by-step reasoning, outlining how you arrived at the rule. Include relevant knowledge about
Sigma syntax, threat detection context, and how you chose each field.
</think>

\\boxed{
<YAML Sigma Rule>
}

**Important Rules:**
- DO NOT skip the <think> tags — all your thoughts must be inside them.
- DO NOT skip the \\boxed{...} wrapper — your final YAML answer MUST go inside it.
- DO NOT output anything outside the <think> and \\boxed{} blocks.
- The content inside \\boxed{} must be **pure YAML**, with **no extra formatting characters**
(no bullets, backticks, emojis, or markdown) so it can be passed **directly** to `yaml.safe_load()`.
- You are allocated a maximum of 2048 tokens — be detailed but concise.
- Your final output must be valid YAML for a Sigma rule, with correct indentation and field names.
- Use Sigma best practices: include `detection.condition`, `detection.selection`,
`logsource`, and `timeframe` when appropriate.
- Match the format and style of this example:

Example:
<think>
This rule is intended to detect potential lateral movement attempts by monitoring for excessive
outbound connections. The condition 'selection | count(dst_ip) by src_ip > 10' flags when one
source IP connects to more than 10 destination IPs. The log source is firewall logs, and the
action is 'denied'.
</think>

\\boxed{
detection:
  condition: selection | count(dst_ip) by src_ip > 10
  selection:
    action: denied
  timeframe: 24h
logsource:
  category: firewall
}

Only produce answers in this format. Think first, then answer clearly in YAML. Follow YAML syntax
exactly.
"""


class SigmaRuleEntry(TypedDict):
    prompt: str
    rule: str


class SigmaRuleEnv(BaseEnv):

    name = "llm_judge_sigmarule"

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.percent_correct_buffer = list()
        self.eval_metrics = list()
        # Tracking for wandb visualizations
        self.rollouts_for_wandb = []
        self.completion_lengths = []

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        env_config = BaseEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            max_token_length=2048,
            wandb_name="sigma_rule_llm_judge",
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
                base_url="http://localhost:9001/v1",
                num_requests_for_eval=256,
            ),
        ]

        return env_config, server_configs

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        if wandb_metrics is None:
            wandb_metrics = {}

        # Calculate percent_correct; skip if the buffer is empty
        try:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)
        except ZeroDivisionError:
            pass

        self.percent_correct_buffer = list()
        for item in self.eval_metrics:
            wandb_metrics[item[0]] = item[1]
        self.eval_metrics = list()
        # Call the parent method to handle the server metrics
        await super().wandb_log(wandb_metrics)

    async def setup(self):
        self.train = load_dataset(
            "mmaisel1/nous-rl-hackathon-sigma", split="train"
        ).shuffle(seed=42)
        test_data = load_dataset(
            "mmaisel1/nous-rl-hackathon-sigma", split="test"
        ).shuffle(seed=42)
        self.test = list()
        for item in test_data:
            self.test.append({"question": item["prompt"], "gold_answer": item["rule"]})
        self.iter = 0

    def save_checkpoint(self, step, data=None):
        if data is None:
            data = {}
        data["iter"] = self.iter
        super().save_checkpoint(step, data)

    async def llm_judge_similarity(
        self, gold_yaml_str: str, gen_yaml_str: str
    ) -> float:
        """
        Uses an LLM to decide whether the generated Sigma rule is semantically equivalent
        to the gold answer. Returns 1.0 if yes, 0.0 if no.
        """
        prompt = f"""
You are an expert in cybersecurity and YAML-based Sigma rules. Given a gold standard
Sigma rule and a generated rule, determine if the generated rule is functionally
equivalent and correct.

Only respond with "1" if the generated Sigma rule is correct and matches the intent
and structure of the gold rule.
Otherwise, respond with "0".

GOLD RULE:
{gold_yaml_str}

GENERATED RULE:
{gen_yaml_str}

Answer with only one character: "1" or "0".
"""

        try:
            completion = await self.server.chat_completion(
                messages=[
                    {"role": "user", "content": prompt},
                ],
                n=1,
                max_tokens=1,
                temperature=0.0,
                split="eval",
            )
            reply = completion.choices[0].message.content.strip()
            return 1.0 if reply == "1" else 0.0

        except Exception as e:
            print(f"LLM Judge error: {e}")
            return 0.0

    async def rollout_and_score_eval(self, question: str, answer: str) -> number:
        completion = await self.server.chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": question},
            ],
            n=1,
            max_tokens=self.config.max_token_length,
            temperature=0.0,
            split="eval",
        )
        # Extract the boxed YAML from the completion and let the LLM judge
        # compare it against the gold rule.
        boxed_match = re.search(
            r"\\boxed\{(.*)\}\s*$",
            completion.choices[0].message.content.split("</think>")[-1],
            re.DOTALL,
        )
        if not boxed_match:
            return 0.0
        return await self.llm_judge_similarity(answer, boxed_match.group(1))

    async def evaluate(self, *args, **kwargs):
        eval_tasks = []
        for item in self.test:
            eval_tasks.append(
                self.rollout_and_score_eval(item["question"], item["gold_answer"])
            )
        scores = await tqdm_asyncio.gather(*eval_tasks)
        self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))

    async def collect_trajectories(
        self, item: SigmaRuleEntry
    ) -> Tuple[ScoredDataGroup, list[Item]]:
        user_message = {"role": "user", "content": item["prompt"]}
        gold_answer = (
            "\\boxed{" + item["rule"].split("#")[-1].strip().replace(",", "") + "}"
        )

        chat_completions = await self.server.chat_completion(
            messages=[{"role": "system", "content": system_prompt}, user_message],
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
        )
        to_score = list()
        to_backlog = list()
        for chat_completion in chat_completions.choices:
            messages = (
                {"role": "system", "content": system_prompt},
                user_message,
                {"role": "assistant", "content": chat_completion.message.content},
            )
            to_score.append(
                {
                    "messages": messages,
                    "gold_answer": gold_answer,
                    "finish_reason": chat_completion.finish_reason,
                }
            )
        to_postprocess = await self.score(to_score)
        return to_postprocess, to_backlog

    async def score(
        self, rollout_group_data
    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
        scores = ScoredDataGroup()
        scores["tokens"] = list()
        scores["masks"] = list()
        scores["scores"] = list()

        # The gold answer is kept as the raw boxed string; each rollout's boxed
        # YAML is extracted below and handed to the LLM judge.
        gold_det = rollout_group_data[0]["gold_answer"]

        random.shuffle(rollout_group_data)

        for item in rollout_group_data:
            out_dict = tokenize_for_trainer(
                self.tokenizer, item["messages"], item["finish_reason"]
            )
            tokens = out_dict["tokens"]
            masks = out_dict["masks"]
            boxed_match = re.search(
                r"\\boxed\{(.*)\}\s*$", item["messages"][-1]["content"], re.DOTALL
            )
            try:
                if boxed_match:
                    yaml_str = boxed_match.group(1)
                    reward = await self.llm_judge_similarity(gold_det, yaml_str)
                else:
                    reward = 0
            except Exception as e:
                print(e)
                reward = 0

            print("GOLD ANSWER:", rollout_group_data[0]["gold_answer"])
            print("GEN OUTPUT:", item["messages"][-1]["content"])
            print("REWARD:", reward)

            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["scores"].append(reward)

        # Buffer for analysis
        for score in scores["scores"]:
            self.percent_correct_buffer.append(score)

        # Apply a length penalty when every rollout already scores ~1.0
        if all(score >= 0.99 for score in scores["scores"]):  # float-safe check
            token_lengths = [len(token) for token in scores["tokens"]]
            if max(token_lengths) == 0:
                return None
            max_allowed_length = self.config.max_token_length
            length_threshold = max_allowed_length * 0.5
            scores["scores"] = []
            for length in token_lengths:
                if length <= length_threshold:
                    scores["scores"].append(1.0)
                else:
                    pct_range = (length - length_threshold) / (
                        max_allowed_length - length_threshold
                    )
                    scores["scores"].append(1.0 - min(pct_range, 1.0))

        return scores if scores["tokens"] else None

    async def get_next_item(self) -> SigmaRuleEntry:
        next_item = self.train[self.iter % len(self.train)]
        self.iter += 1
        return next_item


if __name__ == "__main__":
    SigmaRuleEnv.cli()