import random from typing import Dict, List, Optional, Tuple import numpy as np import pennylane as qml from datasets import load_dataset from pydantic import Field from transformers import AutoTokenizer from atroposlib.envs.base import ( APIServerConfig, BaseEnv, BaseEnvConfig, ScoredDataGroup, ) from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer class QuantumHybridConfig(BaseEnvConfig): """Configuration for the Quantum-Classical Hybrid Environment.""" # Quantum circuit parameters n_qubits: int = Field(8, description="Number of qubits in the quantum circuit") n_layers: int = Field( 3, description="Number of quantum layers to use in the circuit" ) # Dataset parameters dataset_name: str = Field( "wikitext", description="Dataset to use for training/evaluation" ) dataset_config: str = Field( "wikitext-2-raw-v1", description="Dataset configuration" ) sequence_length: int = Field(256, description="Length of sequences to process") # Base model parameters base_model_name: str = Field( "gpt2", description="Base model for hybrid experiments" ) # Training parameters perplexity_weight: float = Field( 0.7, description="Weight for perplexity in scoring" ) quantum_weight: float = Field( 0.3, description="Weight for quantum-specific metrics" ) class QuantumTextAnalyzer: """Standalone quantum analyzer for text coherence measurement.""" def __init__(self, n_qubits=6): self.n_qubits = n_qubits self.dev = qml.device("default.qubit", wires=n_qubits) @qml.qnode(self.dev) def text_analysis_circuit(text_features): # Embed text features as quantum states for i in range(self.n_qubits): qml.RY(text_features[i], wires=i) # Create entanglement patterns for i in range(self.n_qubits - 1): qml.CNOT(wires=[i, i + 1]) # Ring closure if self.n_qubits > 1: qml.CNOT(wires=[self.n_qubits - 1, 0]) # Additional entanglement for complex analysis for i in range(0, self.n_qubits - 2, 2): qml.CNOT(wires=[i, i + 2]) # Measure coherence return [qml.expval(qml.PauliZ(i)) for i in range(self.n_qubits)] self.circuit = text_analysis_circuit def analyze_text(self, text: str) -> float: """Analyze text and return quantum coherence score.""" try: # Extract text features text_len = min(len(text), 200) / 200.0 word_count = len(text.split()) / 100.0 if text.split() else 0.0 char_diversity = len(set(text.lower())) / 26.0 if text else 0.0 avg_word_len = ( np.mean([len(word) for word in text.split()]) / 15.0 if text.split() else 0.0 ) # Additional features punctuation_ratio = sum(1 for c in text if c in ".,!?;:") / max( len(text), 1 ) uppercase_ratio = sum(1 for c in text if c.isupper()) / max(len(text), 1) # Encode as quantum features features = [ text_len * np.pi, word_count * np.pi, char_diversity * np.pi, avg_word_len * np.pi, punctuation_ratio * np.pi, uppercase_ratio * np.pi, ] # Run quantum analysis measurements = self.circuit(features) # Calculate coherence as normalized measurement variance coherence = np.var(measurements) / (np.var(measurements) + 0.1) return float(np.clip(coherence, 0.0, 1.0)) except Exception as e: print(f"Quantum text analysis error: {e}") # Fallback to simple heuristic return min(1.0, len(text) / 100.0) * 0.7 + 0.2 class QuantumHybridEnv(BaseEnv): """Environment for training and evaluating quantum-classical hybrid models.""" def __init__( self, config: QuantumHybridConfig, server_configs: List[APIServerConfig], slurm=True, testing=False, ): super().__init__(config, server_configs, slurm, testing) self.config = config self.metrics_buffer = { "perplexity": [], "quantum_coherence": [], "combined_score": [], "quantum_variance": [], } # Initialize eval_metrics list (required for BaseEnv) self.eval_metrics = [] # Initialize quantum text analyzer self.quantum_analyzer = QuantumTextAnalyzer(n_qubits=config.n_qubits) # Track training iteration self.iter = 0 @classmethod def config_init(cls) -> Tuple[QuantumHybridConfig, List[APIServerConfig]]: """Initialize default configuration.""" config = QuantumHybridConfig( tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview", group_size=4, use_wandb=True, max_num_workers=-1, rollout_server_url="http://localhost:8000", # Atropos server total_steps=20, batch_size=-1, max_token_length=2048, data_path_to_save_groups="data/quantum_hybrid.jsonl", n_qubits=8, n_layers=3, dataset_name="wikitext", dataset_config="wikitext-2-raw-v1", sequence_length=256, base_model_name="gpt2", perplexity_weight=0.7, quantum_weight=0.3, ) # The server config here tells Atropos to route to your vLLM server # This should match whatever model name the Atropos server expects server_configs = [ APIServerConfig( model_name="hermes-3-8b", # This model name should be registered in Atropos base_url="http://localhost:9001/v1", # Your vLLM server api_key="x", # Placeholder for local server timeout=300, num_max_requests_at_once=8, num_requests_for_eval=4, health_check=False, server_type="openai", n_kwarg_is_ignored=False, rolling_buffer_length=100, ), ] return config, server_configs async def setup(self): """Set up the environment, including loading datasets.""" print( f"Setting up QuantumHybridEnv with quantum parameters: " f"{self.config.n_qubits} qubits, {self.config.n_layers} layers" ) # Initialize tokenizer self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_name) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token # Load dataset try: dataset = load_dataset( self.config.dataset_name, self.config.dataset_config, split="train", streaming=True, ) self.dataset = dataset.take(10000) # Take first 10k examples eval_dataset = load_dataset( self.config.dataset_name, self.config.dataset_config, split="validation", streaming=True, ) self.eval_dataset = eval_dataset.take(100) # Take first 100 for eval print( f"Loaded dataset: {self.config.dataset_name}/{self.config.dataset_config}" ) except Exception as e: print(f"Failed to load dataset: {e}") # Create a mock dataset for testing self.dataset = [ { "text": f"Sample text {i} for quantum analysis. This text demonstrates complex linguistic patterns." } for i in range(1000) ] self.eval_dataset = [ {"text": f"Eval text {i} with various complexity levels."} for i in range(100) ] print("Using mock dataset") # Convert to lists for easier iteration if hasattr(self.dataset, "__iter__"): self.train_examples = list(self.dataset) else: self.train_examples = self.dataset if hasattr(self.eval_dataset, "__iter__"): self.eval_examples = list(self.eval_dataset) else: self.eval_examples = self.eval_dataset print(f"Loaded {len(self.train_examples)} training examples") print(f"Loaded {len(self.eval_examples)} evaluation examples") async def get_next_item(self): """Get the next training item from the dataset.""" # Get next instance from the dataset if self.iter >= len(self.train_examples): self.iter = 0 # Reset to beginning data_point = self.train_examples[self.iter] self.iter += 1 # Process text data if isinstance(data_point, dict) and "text" in data_point: text = data_point["text"] else: text = str(data_point) # Truncate text to reasonable length text = text[: self.config.sequence_length * 4] # Allow for some context # Create a simple continuation task words = text.split() if len(words) > 10: # Split at a random point for continuation split_idx = random.randint(5, min(len(words) - 5, 20)) prompt_text = " ".join(words[:split_idx]) target_text = " ".join(words[split_idx : split_idx + 20]) else: prompt_text = text[: len(text) // 2] target_text = text[len(text) // 2 :] # Convert to messages format for Atropos user_msg = {"role": "user", "content": f"Continue this text: {prompt_text}"} # Return as (messages, target) tuple return ([user_msg], target_text) async def collect_trajectories(self, item: Tuple) -> Tuple[ScoredDataGroup, List]: """Generate and collect model responses for scoring.""" messages, target_text = item print(f"Generating completions for: {messages[0]['content'][:50]}...") to_score = [] # Generate multiple completions with different temperatures temperatures = [0.6, 0.8, 1.0, 1.2][: self.config.group_size] for i, temp in enumerate(temperatures): try: # Use the Atropos server to generate completions completion = await self.server.completion( prompt=self.tokenizer.apply_chat_template(messages, tokenize=False), n=1, # Generate one completion at a time max_tokens=min(100, self.config.max_token_length // 4), temperature=temp, ) completion_text = completion.choices[0].text.strip() print( f"Generated completion {i+1} (T={temp}): {completion_text[:50]}..." ) # Create trajectory messages trajectory_messages = messages.copy() trajectory_messages.append( {"role": "assistant", "content": completion_text} ) # Add to scoring queue to_score.append( ( tuple([frozenset(msg.items()) for msg in trajectory_messages]), target_text, completion_text, ) ) except Exception as e: print(f"Error generating completion {i+1} with temperature {temp}: {e}") # Create a mock completion as fallback mock_text = ( f"Mock quantum-enhanced response {i+1}: This demonstrates " f"coherent language generation with temperature {temp}." ) trajectory_messages = messages.copy() trajectory_messages.append({"role": "assistant", "content": mock_text}) to_score.append( ( tuple([frozenset(msg.items()) for msg in trajectory_messages]), target_text, mock_text, ) ) # Score the generated trajectories scored_data = await self.score(to_score) return scored_data, [] async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]: """Score model outputs based on perplexity and quantum metrics.""" if not rollout_group_data: return None scores = ScoredDataGroup() scores["tokens"] = [] scores["masks"] = [] scores["scores"] = [] print(f"Scoring {len(rollout_group_data)} completions...") for item in rollout_group_data: frozen_messages, target_text, completion_text = item # Convert frozen messages back to regular format messages = [dict(frozen_msg) for frozen_msg in frozen_messages] # Convert messages to tokens try: tokenized = tokenize_for_trainer(self.tokenizer, messages) tokens = tokenized["tokens"] masks = tokenized["masks"] except Exception as e: print(f"Tokenization error: {e}") continue # Calculate text similarity score (proxy for perplexity) completion_words = set(completion_text.lower().split()) target_words = set(target_text.lower().split()) # Jaccard similarity if completion_words and target_words: intersection = len(completion_words & target_words) union = len(completion_words | target_words) similarity_score = intersection / union if union > 0 else 0.0 else: similarity_score = 0.0 # Length penalty/bonus len_ratio = len(completion_text) / max(len(target_text), 1) len_score = 1.0 - abs(1.0 - len_ratio) if len_ratio > 0 else 0.0 # Combined perplexity-like score perplexity_score = 0.7 * similarity_score + 0.3 * len_score # Quantum coherence analysis quantum_coherence = self.quantum_analyzer.analyze_text(completion_text) # Calculate quantum variance for additional insight quantum_variance = ( abs(quantum_coherence - 0.5) * 2 ) # Distance from maximum entropy # Combined score using weighted sum combined_score = ( self.config.perplexity_weight * perplexity_score + self.config.quantum_weight * quantum_coherence ) print( f" Similarity: {similarity_score:.3f}, Quantum: {quantum_coherence:.3f}, " f"Combined: {combined_score:.3f}" ) # Update metrics buffer self.metrics_buffer["perplexity"].append(perplexity_score) self.metrics_buffer["quantum_coherence"].append(quantum_coherence) self.metrics_buffer["combined_score"].append(combined_score) self.metrics_buffer["quantum_variance"].append(quantum_variance) # Store data for training (scale to [-1, 1] range) scores["tokens"].append(tokens) scores["masks"].append(masks) scores["scores"].append(2 * combined_score - 1) return scores if scores["scores"] else None async def evaluate(self, *args, **kwargs): """Evaluate the model on a test set.""" print("Running quantum-enhanced evaluation...") eval_scores = [] quantum_scores = [] # Get evaluation examples eval_examples = self.eval_examples[: min(5, len(self.eval_examples))] for example in eval_examples: try: # Process text if isinstance(example, dict) and "text" in example: text = example["text"] else: text = str(example) # Create continuation task words = text.split() if len(words) > 8: split_idx = len(words) // 2 prompt_text = " ".join(words[:split_idx]) target_text = " ".join(words[split_idx:]) else: prompt_text = text[: len(text) // 2] target_text = text[len(text) // 2 :] # Create messages messages = [ {"role": "user", "content": f"Continue this text: {prompt_text}"} ] # Generate completion completion = await self.server.completion( prompt=self.tokenizer.apply_chat_template(messages, tokenize=False), n=1, max_tokens=50, temperature=0.7, split="eval", ) completion_text = completion.choices[0].text.strip() # Calculate metrics completion_words = set(completion_text.lower().split()) target_words = set(target_text.lower().split()) if completion_words and target_words: intersection = len(completion_words & target_words) union = len(completion_words | target_words) similarity = intersection / union if union > 0 else 0.0 else: similarity = 0.0 # Quantum analysis quantum_coherence = self.quantum_analyzer.analyze_text(completion_text) # Combined evaluation score eval_score = ( self.config.perplexity_weight * similarity + self.config.quantum_weight * quantum_coherence ) eval_scores.append(eval_score) quantum_scores.append(quantum_coherence) except Exception as e: print(f"Evaluation error: {e}") continue # Log evaluation metrics if eval_scores: avg_eval_score = sum(eval_scores) / len(eval_scores) avg_quantum_score = sum(quantum_scores) / len(quantum_scores) self.eval_metrics.append(("eval/combined_score", avg_eval_score)) self.eval_metrics.append(("eval/quantum_coherence", avg_quantum_score)) self.eval_metrics.append( ("eval/perplexity_weight", self.config.perplexity_weight) ) self.eval_metrics.append( ("eval/quantum_weight", self.config.quantum_weight) ) print( f"Evaluation complete: avg_score={avg_eval_score:.3f}, avg_quantum={avg_quantum_score:.3f}" ) async def wandb_log(self, wandb_metrics: Optional[Dict] = None): """Log metrics to Weights & Biases.""" if wandb_metrics is None: wandb_metrics = {} # Calculate and add metrics from buffer for metric_name, values in self.metrics_buffer.items(): if values: wandb_metrics[f"train/{metric_name}_avg"] = sum(values) / len(values) wandb_metrics[f"train/{metric_name}_max"] = max(values) wandb_metrics[f"train/{metric_name}_min"] = min(values) wandb_metrics[f"train/{metric_name}_std"] = np.std(values) # Add quantum-specific metrics wandb_metrics["quantum/n_qubits"] = self.config.n_qubits wandb_metrics["quantum/n_layers"] = self.config.n_layers wandb_metrics["quantum/weight"] = self.config.quantum_weight wandb_metrics["train/perplexity_weight"] = self.config.perplexity_weight # Clear the buffer self.metrics_buffer = {key: [] for key in self.metrics_buffer} # Add evaluation metrics for name, value in self.eval_metrics: wandb_metrics[name] = value self.eval_metrics = [] # Log to wandb using the parent method await super().wandb_log(wandb_metrics) if __name__ == "__main__": # Launch the environment with CLI arguments QuantumHybridEnv.cli()