import os
import random
import re
from typing import Dict, List, Optional, Tuple, TypedDict, Union

import yaml
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig  # noqa: F401  (no longer used by scoring)
from math_verify import LatexExtractionConfig, parse  # noqa: F401  (no longer used by scoring)
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import jaccard_score
from tqdm.asyncio import tqdm_asyncio

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)
from atroposlib.type_definitions import Item, number
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

# Placeholder key for the local OpenAI-compatible server configured in
# `config_init` (presumably the client only needs *some* value). `setdefault`
# avoids overwriting a real key that is already present in the environment.
os.environ.setdefault("OPENAI_API_KEY", "x")

# Marker that closes the model's reasoning section. Only the text after the
# last occurrence is searched for the final answer, so a `\boxed{}` mentioned
# while reasoning is not mistaken for the answer.
THINK_END_TAG = "</think>"

# Greedy match from the first `\boxed{` to the last `}` that is followed only
# by trailing whitespace, so braces nested inside the YAML are kept.
BOXED_PATTERN = re.compile(r"\\boxed\{(.*)\}\s*$", re.DOTALL)

system_prompt = """
You are a systematic, deep-reasoning AI trained to solve detection engineering problems by constructing Sigma rules. You must first **think carefully** through the problem using structured internal monologue enclosed in <think>...</think> tags, and then provide your final Sigma rule as a YAML block enclosed in a LaTeX \\boxed{...} environment.

**You MUST follow this exact output format:**

<think>
Step-by-step reasoning, outlining how you arrived at the rule. Include relevant knowledge about Sigma syntax, threat detection context, and how you chose each field.
</think>

\\boxed{
<YAML Sigma rule>
}

**Important Rules:**
- DO NOT skip the <think> tags — all your thoughts must be inside them.
- DO NOT skip the \\boxed{...} wrapper — your final YAML answer MUST go inside it.
- DO NOT output anything outside the <think> and \\boxed{} blocks.
- The content inside \\boxed{} must be **pure YAML**, with **no extra formatting characters** (no bullets, backticks, emojis, or markdown) so it can be passed **directly** to `yaml.safe_load()`.
- You are allocated a maximum of 2048 tokens — be detailed but concise.
- Your final output must be valid YAML for a Sigma rule, with correct indentation and field names.
- Use Sigma best practices: include `detection.condition`, `detection.selection`, `logsource`, and `timeframe` when appropriate.
- Match the format and style of this example:

Example:
<think>
This rule is intended to detect potential lateral movement attempts by monitoring for excessive outbound connections. The condition 'selection | count(dst_ip) by src_ip > 10' flags when one source IP connects to more than 10 destination IPs. The log source is firewall logs, and the action is 'denied'.
</think>

\\boxed{
detection:
  condition: selection | count(dst_ip) by src_ip > 10
  selection:
    action: denied
  timeframe: 24h
logsource:
  category: firewall
}

Only produce answers in this format. Think first, then answer clearly in YAML. Follow YAML syntax exactly.
"""


class SigmaRuleEntry(TypedDict):
    """One training row of `mmaisel1/nous-rl-hackathon-sigma`.

    Field names match the columns read in `setup` and `collect_trajectories`.
    """

    prompt: str  # task description shown to the model as the user message
    rule: str  # reference Sigma rule text (presumably YAML; parsed with yaml.safe_load in eval)


class SigmaRuleEnv(BaseEnv):
    """RL environment rewarding Sigma rules that resemble a reference rule.

    Training reward is the token-set Jaccard similarity between the reference
    rule text and the YAML inside the completion's final `\\boxed{...}`.
    Evaluation parses both rules as YAML and compares only their flattened
    `detection` sections.
    """

    name = "sigmarule"

    def __init__(
        self,
        config: BaseEnvConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        # Per-rollout training rewards accumulated between wandb_log calls.
        self.percent_correct_buffer = list()
        # (metric_name, value) pairs produced by `evaluate`, flushed in wandb_log.
        self.eval_metrics = list()
        # Add tracking for wandb visualizations
        self.rollouts_for_wandb = []
        self.completion_lengths = []

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[APIServerConfig]]:
        """Return the default environment config and inference-server configs."""
        env_config = BaseEnvConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
            group_size=8,
            use_wandb=True,
            rollout_server_url="http://localhost:8000",
            total_steps=1000,
            batch_size=12,
            steps_per_eval=100,
            max_token_length=2048,
            # Was "gsm8k" (copy-paste leftover); name runs after this env.
            wandb_name=cls.name,
        )
        server_configs = [
            APIServerConfig(
                model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
                base_url="http://localhost:9001/v1",
                num_requests_for_eval=256,
            ),
        ]
        return env_config, server_configs

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Add training/eval metrics to `wandb_metrics` and forward to the base class.

        `train/percent_correct` is the mean Jaccard reward (in [0, 1]) of the
        rollouts scored since the previous call; it is omitted when none were.
        """
        if wandb_metrics is None:
            wandb_metrics = {}
        if self.percent_correct_buffer:
            wandb_metrics["train/percent_correct"] = sum(
                self.percent_correct_buffer
            ) / len(self.percent_correct_buffer)
        self.percent_correct_buffer = list()

        for metric_name, metric_value in self.eval_metrics:
            wandb_metrics[metric_name] = metric_value
        self.eval_metrics = list()

        # Call the parent method to handle the server metrics
        await super().wandb_log(wandb_metrics)

    async def setup(self):
        """Load the train/test splits (shuffled with seed 42) and reset the cursor."""
        self.train = load_dataset(
            "mmaisel1/nous-rl-hackathon-sigma", split="train"
        ).shuffle(seed=42)
        test_data = load_dataset(
            "mmaisel1/nous-rl-hackathon-sigma", split="test"
        ).shuffle(seed=42)
        self.test = [
            {"question": row["prompt"], "gold_answer": row["rule"]}
            for row in test_data
        ]
        self.iter = 0

    def save_checkpoint(self, step, data=None):
        """Save the dataset cursor (`iter`) alongside the base-class checkpoint data."""
        if data is None:
            data = {}
        data["iter"] = self.iter
        super().save_checkpoint(step, data)

    def flatten_detection(self, det_dict):
        """Render a mapping as space-separated `key:value` pairs; non-dicts give ""."""
        if not isinstance(det_dict, dict):
            return ""
        return " ".join(f"{k}:{v}" for k, v in det_dict.items())

    def jaccard_similarity(self, str1, str2):
        """Return the Jaccard similarity of the word-token sets of two strings.

        Tokens come from `CountVectorizer` defaults (lower-cased words of two or
        more characters). Returns 0.0 when either string has no tokens,
        including the case where neither does (where `fit` would raise
        ValueError for an empty vocabulary).
        """
        try:
            vectorizer = CountVectorizer(binary=True).fit([str1, str2])
        except ValueError:
            # Empty vocabulary: neither string contains a word token.
            return 0.0
        v1 = vectorizer.transform([str1]).toarray()[0]
        v2 = vectorizer.transform([str2]).toarray()[0]
        if not (v1.any() and v2.any()):
            return 0.0
        return float(jaccard_score(v1, v2))

    @staticmethod
    def _extract_boxed_yaml(completion_text: Optional[str]) -> Optional[str]:
        """Return the content of the final `\\boxed{...}` after the reasoning, or None."""
        if not completion_text:
            return None
        answer_section = completion_text.split(THINK_END_TAG)[-1]
        match = BOXED_PATTERN.search(answer_section)
        return match.group(1) if match else None

    @staticmethod
    def _load_yaml_mapping(yaml_text: Optional[str]) -> dict:
        """Parse `yaml_text` with `yaml.safe_load`; return {} if empty, invalid, or not a mapping."""
        if not yaml_text:
            return {}
        try:
            loaded = yaml.safe_load(yaml_text)
        except yaml.YAMLError:
            return {}
        return loaded if isinstance(loaded, dict) else {}

    async def rollout_and_score_eval(self, question: str, answer: str) -> number:
        """Generate one greedy completion for `question` and score it against `answer`.

        Both the reference rule and the boxed YAML from the completion are
        parsed with `yaml.safe_load`; the score is the Jaccard similarity of
        their flattened `detection` sections (0.0 if either is missing or
        unparseable).
        """
        completion = await self.server.chat_completion(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": question},
            ],
            n=1,
            max_tokens=self.config.max_token_length,
            temperature=0.0,
            split="eval",
        )
        gold_rule = self._load_yaml_mapping(answer)
        generated_rule = self._load_yaml_mapping(
            self._extract_boxed_yaml(completion.choices[0].message.content)
        )
        return self.jaccard_similarity(
            self.flatten_detection(gold_rule.get("detection")),
            self.flatten_detection(generated_rule.get("detection")),
        )

    async def evaluate(self, *args, **kwargs):
        """Score every test item concurrently and queue the mean as `eval/percent_correct`."""
        if not self.test:
            return
        eval_tasks = [
            self.rollout_and_score_eval(item["question"], item["gold_answer"])
            for item in self.test
        ]
        scores = await tqdm_asyncio.gather(*eval_tasks)
        self.eval_metrics.append(("eval/percent_correct", sum(scores) / len(scores)))

    async def collect_trajectories(
        self, item: SigmaRuleEntry
    ) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
        """Sample `group_size` completions for one training row and score them.

        Returns the scored group (or None, see `score`) and an empty backlog.
        """
        user_message = {"role": "user", "content": item["prompt"]}
        # The reference rule is compared as-is. (The old GSM8K-style
        # `split("#")` / comma stripping / `\boxed{}` wrapping corrupted it and
        # added a `boxed` token no answer could match.)
        gold_answer = item["rule"].strip()

        chat_completions = await self.server.chat_completion(
            messages=[{"role": "system", "content": system_prompt}, user_message],
            n=self.config.group_size,
            max_tokens=self.config.max_token_length,
        )

        to_score = [
            {
                "messages": (
                    {"role": "system", "content": system_prompt},
                    user_message,
                    {"role": "assistant", "content": choice.message.content},
                ),
                "gold_answer": gold_answer,
                "finish_reason": choice.finish_reason,
            }
            for choice in chat_completions.choices
        ]
        to_backlog: List[Item] = []

        to_postprocess = await self.score(to_score)
        return to_postprocess, to_backlog

    async def score(
        self, rollout_group_data: List[Dict]
    ) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]:
        """Tokenize and reward a group of rollouts that share one reference rule.

        Each rollout's reward is the Jaccard similarity between the reference
        rule text and the YAML inside its final `\\boxed{...}` (0.0 if none).
        If every reward is >= 0.99, rewards are replaced by a length penalty:
        1.0 up to half of `max_token_length`, then decreasing linearly to 0.0
        at `max_token_length`.

        Returns None for an empty group or when every tokenized rollout is
        empty; otherwise the populated ScoredDataGroup.
        """
        if not rollout_group_data:
            return None

        scores = ScoredDataGroup()
        scores["tokens"] = list()
        scores["masks"] = list()
        scores["scores"] = list()

        gold_rule = rollout_group_data[0]["gold_answer"]

        random.shuffle(rollout_group_data)
        for item in rollout_group_data:
            out_dict = tokenize_for_trainer(
                self.tokenizer, item["messages"], item["finish_reason"]
            )
            completion_text = item["messages"][-1]["content"]
            generated_yaml = self._extract_boxed_yaml(completion_text)
            if generated_yaml is None:
                reward = 0.0
            else:
                reward = self.jaccard_similarity(gold_rule, generated_yaml)

            print("GOLD ANSWER:", gold_rule)
            print("GEN OUTPUT:", completion_text)
            print("REWARD:", reward)

            scores["tokens"].append(out_dict["tokens"])
            scores["masks"].append(out_dict["masks"])
            scores["scores"].append(reward)

        # Buffer for analysis (logged as train/percent_correct).
        self.percent_correct_buffer.extend(scores["scores"])

        # Length penalty when the whole group is (float-safely) perfect.
        if all(reward >= 0.99 for reward in scores["scores"]):
            token_lengths = [len(tokens) for tokens in scores["tokens"]]
            if max(token_lengths) == 0:
                return None

            max_allowed_length = self.config.max_token_length
            length_threshold = max_allowed_length * 0.5
            scores["scores"] = [
                1.0
                if length <= length_threshold
                else 1.0
                - min(
                    (length - length_threshold)
                    / (max_allowed_length - length_threshold),
                    1.0,
                )
                for length in token_lengths
            ]

        return scores if scores["tokens"] else None

    async def get_next_item(self) -> SigmaRuleEntry:
        """Return the next training row, wrapping around the shuffled split."""
        next_item = self.train[self.iter % len(self.train)]
        self.iter += 1
        return next_item


if __name__ == "__main__":
    SigmaRuleEnv.cli()