linted, moved to community folder

This commit is contained in:
Shannon Sands 2025-05-26 14:10:26 +10:00
parent 20c6e9d8d7
commit b845c635d4
7 changed files with 573 additions and 290 deletions


@@ -918,6 +918,190 @@ python environments/community/poker_holdem/poker_env.py process \
**Requirements**: datasets, transformers, wandb, atroposlib
### 18. Quantum-Classical Hybrid Language Model Environment (`quantum_hybrid/`)
**Author**: [jeannemtl](https://github.com/jeannemtl)
**Purpose**: Train quantum-enhanced language models by combining classical transformers with quantum circuits using PennyLane and PyTorch
This environment implements a quantum-classical hybrid architecture for next-word prediction, trained on high-quality text generated by Hermes-3-70B. Its key innovation is using quantum circuits to augment a traditional neural network head for language modeling, exploring potential quantum advantages in natural language processing.
**Research Question**: Can quantum circuits provide advantages over purely classical approaches in natural language processing tasks?
**Architecture Overview**:
- **Data Flow**: Input Prompts → Hermes-3-70B (text generation) → Hybrid Model Training → Quantum-Enhanced Predictions
- **Hybrid Model Components**:
- **Classical Pathway**: Standard transformer-style neural network head
- **Quantum Pathway**: Dimensionality reduction (768D → 8D) → Two quantum circuit layers → Quantum-to-vocabulary mapping
- **Learnable Mixing**: Parameter α balances classical vs quantum contributions
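The learnable mixing can be sketched in a few lines of plain Python (a minimal sketch: the actual model stores α as a learnable torch parameter and applies `torch.sigmoid`, so `mix_logits` and `alpha_raw` here are illustrative names):

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def mix_logits(classical, quantum, alpha_raw):
    # alpha = sigmoid(alpha_raw): 1.0 -> fully classical, 0.0 -> fully quantum
    alpha = sigmoid(alpha_raw)
    return [alpha * c + (1.0 - alpha) * q for c, q in zip(classical, quantum)]

# alpha_raw = 0.0 gives an even 50/50 blend of the two pathways
blended = mix_logits([1.0, 0.0], [0.0, 1.0], alpha_raw=0.0)
```

Because α passes through a sigmoid, the blend is always a convex combination of the two pathways, and gradient descent can move it freely toward either extreme.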
**Quantum Circuit Design**:
- **8 qubits with 3 parameterized layers**
- **RY rotation gates** for classical data encoding
- **CNOT gates** creating entanglement patterns
- **Pauli-Z measurements** for classical output extraction
- **Ring topology** for full qubit connectivity
**Dual Implementation Approach**:
The environment includes two complementary implementations:
**1. Optimized Hybrid Model (`atropos.py`)**:
- **Synthetic Training**: Uses simplified tokenizer and mock hidden states for rapid experimentation
- **Quantum Integration**: Full quantum circuit implementation with PennyLane
- **Hybrid Architecture**: Learnable mixing between classical and quantum pathways
- **Training Loop**: Direct optimization of quantum parameters via gradient descent
- **Evaluation**: Entropy-based comparison of hybrid vs classical predictions
**2. Dataset-Driven Training (`atopos_quant.py`)**:
- **Real Data Processing**: Uses WikiText dataset with HuggingFace integration
- **Quantum Text Analysis**: Standalone quantum analyzer for text coherence measurement
- **Server Integration**: Compatible with Atropos server infrastructure
- **Comprehensive Metrics**: Perplexity, quantum coherence, and combined scoring
- **Production Ready**: Full tokenization and dataset management
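The combined scoring is a simple weighted sum. A minimal sketch using the default weights from `QuantumHybridConfig` (0.7 for perplexity, 0.3 for quantum coherence):

```python
def combined_score(perplexity_score, quantum_coherence,
                   perplexity_weight=0.7, quantum_weight=0.3):
    # Weighted sum of the normalized perplexity score and the
    # quantum coherence measurement (both expected in [0, 1])
    return perplexity_weight * perplexity_score + quantum_weight * quantum_coherence
```

With the defaults, a completion scoring 0.8 on perplexity and 0.5 on coherence receives roughly 0.71 overall.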
**Quantum Text Analysis Features**:
- **Text Feature Extraction**: Length, word count, character diversity, punctuation patterns
- **Quantum Encoding**: Features mapped to quantum states via rotation gates
- **Entanglement Patterns**: Complex qubit interactions for linguistic analysis
- **Coherence Measurement**: Quantum variance as text quality indicator
- **Fallback Mechanisms**: Graceful degradation when quantum circuits fail
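A sketch of the feature-to-angle mapping, mirroring the scaling used by the environment's `QuantumTextAnalyzer` (angles are not clamped to [0, π], matching the source):

```python
import math

def text_to_angles(text: str) -> list:
    """Map simple text statistics to rotation angles (multiples of pi)."""
    words = text.split()
    features = [
        min(len(text), 200) / 200.0,                                  # length
        len(words) / 100.0,                                           # word count
        len(set(text.lower())) / 26.0,                                # character diversity
        (sum(len(w) for w in words) / len(words)) / 15.0 if words else 0.0,
        sum(c in ".,!?;:" for c in text) / max(len(text), 1),         # punctuation ratio
        sum(c.isupper() for c in text) / max(len(text), 1),           # uppercase ratio
    ]
    return [f * math.pi for f in features]

angles = text_to_angles("Quantum circuits analyze text!")
```

Each angle then drives an RY rotation on one qubit, so richer or more varied text rotates its qubits further from the initial state.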
**Training Strategy - Quantum-Enhanced Knowledge Distillation**:
1. **Teacher Model**: Hermes-3-70B generates diverse, high-quality responses
2. **Student Model**: Hybrid quantum-classical model learns next-word prediction
3. **Comparison**: Direct evaluation of quantum vs classical pathways within same model
4. **Optimization**: Both classical and quantum parameters trained via gradient descent
**Key Metrics & Evaluation**:
**Training Metrics**:
- `train/hybrid_loss`: Combined quantum-classical model loss
- `train/classical_loss`: Baseline classical-only model loss
- `train/quantum_loss`: Quantum-specific loss component
- `train/alpha_value`: Mixing parameter (0 = full quantum, 1 = full classical)
**Evaluation Metrics**:
- `eval/hybrid_performance`: Entropy-based comparison of hybrid vs classical predictions
- `eval/quantum_weight`: Current quantum contribution (1 - α)
- `train/quantum_coherence`: Measure of quantum circuit effectiveness
**Model Metrics**:
- `model/alpha`: Real-time mixing parameter
- `model/quantum_contribution`: Percentage of quantum influence
**Interpretation Guide**:
- **Decreasing hybrid_loss**: Model improving at next-word prediction
- **Stable alpha_value**: Balanced classical-quantum integration
- **High quantum_coherence**: Quantum circuits contributing meaningfully
- **hybrid_performance > 0.5**: Quantum enhancement provides benefits
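The entropy comparison behind `eval/hybrid_performance` can be sketched as follows (hypothetical standalone helpers; the implementation computes the same quantity from torch softmax outputs):

```python
import math

def entropy(probs):
    # Shannon entropy in nats; zero-probability entries contribute nothing
    return -sum(p * math.log(p) for p in probs if p > 0)

def hybrid_performance(hybrid_probs, classical_probs):
    # 1.0 when both heads are equally (un)certain,
    # approaching 0.0 as their entropies diverge
    h_hybrid = entropy(hybrid_probs)
    h_classical = entropy(classical_probs)
    return 1.0 - abs(h_hybrid - h_classical) / max(h_hybrid, h_classical)

score = hybrid_performance([0.25] * 4, [0.4, 0.3, 0.2, 0.1])
```

Identical output distributions score exactly 1.0, so values above 0.5 indicate the quantum pathway is producing predictions of comparable confidence to the classical baseline.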
**Technical Implementation Details**:
**Quantum Circuit Architecture**:
```python
# Data encoding: one rotation per qubit
for qubit in range(n_qubits):
    qml.RY(classical_data[qubit], wires=qubit)

# Parameterized layers
for layer in range(n_layers):
    for qubit in range(n_qubits):
        qml.RY(learnable_params[layer, qubit], wires=qubit)
    # Entanglement pattern
    for i in range(n_qubits - 1):
        qml.CNOT(wires=[i, i + 1])
    qml.CNOT(wires=[n_qubits - 1, 0])  # Ring topology

# Measurement
return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]
```
**Training Process**:
1. **Forward Pass**: Hidden states → quantum circuits → predictions
2. **Loss Calculation**: Cross-entropy on next-word prediction
3. **Backpropagation**: Gradients through quantum circuits via parameter-shift rule
4. **Optimization**: Adam optimizer updates both classical and quantum parameters
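Step 3 relies on the parameter-shift rule, which recovers exact gradients of circuit expectation values from two shifted evaluations. A self-contained sketch for a single RY rotation, where the known single-qubit expectation ⟨Z⟩ = cos θ stands in for a full circuit evaluation:

```python
import math

def expval_z(theta: float) -> float:
    # <Z> after RY(theta) on |0> is cos(theta); proxy for a circuit run
    return math.cos(theta)

def parameter_shift_grad(f, theta: float) -> float:
    # Two extra evaluations give the exact gradient for rotation
    # gates whose generators have eigenvalues +-1/2 (RX, RY, RZ)
    s = math.pi / 2
    return (f(theta + s) - f(theta - s)) / 2.0

theta = 0.7
grad = parameter_shift_grad(expval_z, theta)  # matches the analytic -sin(theta)
```

Unlike finite differences, the shifted evaluations are taken at macroscopic spacing (π/2), so the rule stays exact and is robust to measurement noise on hardware.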
**Novel Contributions**:
- **First quantum-enhanced Atropos environment**
- **Hybrid architecture balancing quantum and classical processing**
- **Knowledge distillation from large classical models to small quantum models**
- **Quantum-aware evaluation metrics for NLP tasks**
**Current Limitations**:
- **Simulated Quantum**: Uses classical simulation (no quantum hardware)
- **Synthetic Features**: Uses random hidden states rather than real text embeddings (in the optimized version)
- **Scale**: Limited to 8 qubits due to exponential simulation cost
- **Evaluation**: Simple entropy comparison (more sophisticated metrics possible)
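The 8-qubit ceiling follows from the exponential cost of dense statevector simulation; a back-of-envelope sketch (assuming complex128 amplitudes, and counting memory only — simulation time grows exponentially as well):

```python
def statevector_bytes(n_qubits: int, bytes_per_amplitude: int = 16) -> int:
    # A dense statevector stores 2**n complex amplitudes (complex128 = 16 bytes)
    return (2 ** n_qubits) * bytes_per_amplitude

small = statevector_bytes(8)   # 4 KiB: trivial to simulate
large = statevector_bytes(30)  # 16 GiB: at the limit of a typical workstation
```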
**Potential Applications**:
- **Quantum NLP Research**: Differentiable quantum circuits for language tasks
- **Hybrid Model Architectures**: Resource-constrained environments with quantum enhancement
- **Novel Optimization**: Combining classical and quantum approaches
- **Benchmark Creation**: Quantum machine learning evaluation in language tasks
**Future Research Directions**:
**Immediate Improvements**:
- **Real Text Processing**: Replace synthetic hidden states with actual transformer embeddings
- **Advanced Quantum Circuits**: Implement quantum attention mechanisms
- **Scaling Studies**: Investigate qubit count vs performance relationships
**Long-term Goals**:
- **Quantum Hardware**: Deploy on IBM Quantum, IonQ, or other quantum computers
- **Larger Models**: Scale to 100+ qubit systems when available
- **Quantum Advantage**: Identify specific NLP tasks where quantum provides provable benefits
- **Production Systems**: Develop practical quantum-enhanced language models
**Configuration Options**:
- **Quantum Parameters**: Configurable qubit count (default: 8) and layer depth (default: 3)
- **Training Settings**: Learning rate, batch size, total steps, evaluation frequency
- **Model Architecture**: Base model selection, vocabulary size, hidden dimensions
- **Hybrid Weighting**: Adjustable balance between classical and quantum contributions
- **Dataset Selection**: WikiText variants or custom text datasets
**Setup Requirements**:
1. **PennyLane**: Quantum computing framework
2. **PyTorch**: Deep learning and automatic differentiation
3. **Transformers**: Tokenization and model utilities
4. **Datasets**: HuggingFace dataset loading
5. **NumPy**: Numerical computations
6. **WandB**: Experiment tracking and visualization
**Installation & Usage**:
```bash
# Install quantum dependencies
pip install pennylane torch transformers datasets numpy wandb
# Run optimized hybrid training
python environments/community/quantum_hybrid/atropos.py process \
--env.n_qubits 8 \
--env.n_layers 3 \
--env.total_steps 50 \
--env.quantum_weight 0.3
# Run dataset-driven training
python environments/community/quantum_hybrid/atopos_quant.py process \
--env.dataset_name wikitext \
--env.dataset_config wikitext-2-raw-v1 \
--env.n_qubits 8
```
**Live Experiment Tracking**: Monitor training progress and quantum metrics on the WandB dashboard, with real-time visualization of the quantum-classical balance and performance metrics.
**Research Impact**: This environment represents cutting-edge research in quantum machine learning for NLP. While quantum advantages are still under investigation, the framework provides a foundation for future breakthroughs in quantum-enhanced language processing.
**Repository Structure**:
```
environments/community/quantum_hybrid/
├── atropos.py # Optimized hybrid model implementation
├── atopos_quant.py # Dataset-driven quantum training
├── requirements.txt # Python dependencies
├── README.md # Detailed documentation
├── quantum_hybrid_artifacts.tar.gz # Training artifacts
└── quantum_latest_artifacts.tar.gz # Latest training data
```
**Requirements**: pennylane, torch, transformers, datasets, numpy, pydantic, atroposlib
---
## Support


@@ -1,15 +1,11 @@
import asyncio
import itertools
import random
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple
import numpy as np
import pennylane as qml
import torch
import wandb
from datasets import load_dataset
from pydantic import Field
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoTokenizer
from atroposlib.envs.base import (
APIServerConfig,
@@ -25,19 +21,31 @@ class QuantumHybridConfig(BaseEnvConfig):
# Quantum circuit parameters
n_qubits: int = Field(8, description="Number of qubits in the quantum circuit")
n_layers: int = Field(3, description="Number of quantum layers to use in the circuit")
n_layers: int = Field(
3, description="Number of quantum layers to use in the circuit"
)
# Dataset parameters
dataset_name: str = Field("wikitext", description="Dataset to use for training/evaluation")
dataset_config: str = Field("wikitext-2-raw-v1", description="Dataset configuration")
dataset_name: str = Field(
"wikitext", description="Dataset to use for training/evaluation"
)
dataset_config: str = Field(
"wikitext-2-raw-v1", description="Dataset configuration"
)
sequence_length: int = Field(256, description="Length of sequences to process")
# Base model parameters
base_model_name: str = Field("gpt2", description="Base model for hybrid experiments")
base_model_name: str = Field(
"gpt2", description="Base model for hybrid experiments"
)
# Training parameters
perplexity_weight: float = Field(0.7, description="Weight for perplexity in scoring")
quantum_weight: float = Field(0.3, description="Weight for quantum-specific metrics")
perplexity_weight: float = Field(
0.7, description="Weight for perplexity in scoring"
)
quantum_weight: float = Field(
0.3, description="Weight for quantum-specific metrics"
)
class QuantumTextAnalyzer:
@@ -55,15 +63,15 @@ class QuantumTextAnalyzer:
# Create entanglement patterns
for i in range(self.n_qubits - 1):
qml.CNOT(wires=[i, i+1])
qml.CNOT(wires=[i, i + 1])
# Ring closure
if self.n_qubits > 1:
qml.CNOT(wires=[self.n_qubits-1, 0])
qml.CNOT(wires=[self.n_qubits - 1, 0])
# Additional entanglement for complex analysis
for i in range(0, self.n_qubits-2, 2):
qml.CNOT(wires=[i, i+2])
for i in range(0, self.n_qubits - 2, 2):
qml.CNOT(wires=[i, i + 2])
# Measure coherence
return [qml.expval(qml.PauliZ(i)) for i in range(self.n_qubits)]
@@ -77,10 +85,16 @@ class QuantumTextAnalyzer:
text_len = min(len(text), 200) / 200.0
word_count = len(text.split()) / 100.0 if text.split() else 0.0
char_diversity = len(set(text.lower())) / 26.0 if text else 0.0
avg_word_len = np.mean([len(word) for word in text.split()]) / 15.0 if text.split() else 0.0
avg_word_len = (
np.mean([len(word) for word in text.split()]) / 15.0
if text.split()
else 0.0
)
# Additional features
punctuation_ratio = sum(1 for c in text if c in '.,!?;:') / max(len(text), 1)
punctuation_ratio = sum(1 for c in text if c in ".,!?;:") / max(
len(text), 1
)
uppercase_ratio = sum(1 for c in text if c.isupper()) / max(len(text), 1)
# Encode as quantum features
@@ -90,7 +104,7 @@ class QuantumTextAnalyzer:
char_diversity * np.pi,
avg_word_len * np.pi,
punctuation_ratio * np.pi,
uppercase_ratio * np.pi
uppercase_ratio * np.pi,
]
# Run quantum analysis
@@ -178,7 +192,10 @@ class QuantumHybridEnv(BaseEnv):
async def setup(self):
"""Set up the environment, including loading datasets."""
print(f"Setting up QuantumHybridEnv with quantum parameters: {self.config.n_qubits} qubits, {self.config.n_layers} layers")
print(
f"Setting up QuantumHybridEnv with quantum parameters: "
f"{self.config.n_qubits} qubits, {self.config.n_layers} layers"
)
# Initialize tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_name)
@@ -191,7 +208,7 @@ class QuantumHybridEnv(BaseEnv):
self.config.dataset_name,
self.config.dataset_config,
split="train",
streaming=True
streaming=True,
)
self.dataset = dataset.take(10000) # Take first 10k examples
@@ -199,26 +216,36 @@
self.config.dataset_name,
self.config.dataset_config,
split="validation",
streaming=True
streaming=True,
)
self.eval_dataset = eval_dataset.take(100) # Take first 100 for eval
print(f"Loaded dataset: {self.config.dataset_name}/{self.config.dataset_config}")
print(
f"Loaded dataset: {self.config.dataset_name}/{self.config.dataset_config}"
)
except Exception as e:
print(f"Failed to load dataset: {e}")
# Create a mock dataset for testing
self.dataset = [{"text": f"Sample text {i} for quantum analysis. This text demonstrates complex linguistic patterns."} for i in range(1000)]
self.eval_dataset = [{"text": f"Eval text {i} with various complexity levels."} for i in range(100)]
self.dataset = [
{
"text": f"Sample text {i} for quantum analysis. This text demonstrates complex linguistic patterns."
}
for i in range(1000)
]
self.eval_dataset = [
{"text": f"Eval text {i} with various complexity levels."}
for i in range(100)
]
print("Using mock dataset")
# Convert to lists for easier iteration
if hasattr(self.dataset, '__iter__'):
if hasattr(self.dataset, "__iter__"):
self.train_examples = list(self.dataset)
else:
self.train_examples = self.dataset
if hasattr(self.eval_dataset, '__iter__'):
if hasattr(self.eval_dataset, "__iter__"):
self.eval_examples = list(self.eval_dataset)
else:
self.eval_examples = self.eval_dataset
@@ -242,18 +269,18 @@ class QuantumHybridEnv(BaseEnv):
text = str(data_point)
# Truncate text to reasonable length
text = text[:self.config.sequence_length * 4] # Allow for some context
text = text[: self.config.sequence_length * 4] # Allow for some context
# Create a simple continuation task
words = text.split()
if len(words) > 10:
# Split at a random point for continuation
split_idx = random.randint(5, min(len(words) - 5, 20))
prompt_text = ' '.join(words[:split_idx])
target_text = ' '.join(words[split_idx:split_idx + 20])
prompt_text = " ".join(words[:split_idx])
target_text = " ".join(words[split_idx : split_idx + 20])
else:
prompt_text = text[:len(text)//2]
target_text = text[len(text)//2:]
prompt_text = text[: len(text) // 2]
target_text = text[len(text) // 2 :]
# Convert to messages format for Atropos
user_msg = {"role": "user", "content": f"Continue this text: {prompt_text}"}
@@ -270,7 +297,7 @@ class QuantumHybridEnv(BaseEnv):
to_score = []
# Generate multiple completions with different temperatures
temperatures = [0.6, 0.8, 1.0, 1.2][:self.config.group_size]
temperatures = [0.6, 0.8, 1.0, 1.2][: self.config.group_size]
for i, temp in enumerate(temperatures):
try:
@@ -283,7 +310,9 @@
)
completion_text = completion.choices[0].text.strip()
print(f"Generated completion {i+1} (T={temp}): {completion_text[:50]}...")
print(
f"Generated completion {i+1} (T={temp}): {completion_text[:50]}..."
)
# Create trajectory messages
trajectory_messages = messages.copy()
@@ -292,15 +321,30 @@
)
# Add to scoring queue
to_score.append((tuple([frozenset(msg.items()) for msg in trajectory_messages]), target_text, completion_text))
to_score.append(
(
tuple([frozenset(msg.items()) for msg in trajectory_messages]),
target_text,
completion_text,
)
)
except Exception as e:
print(f"Error generating completion {i+1} with temperature {temp}: {e}")
# Create a mock completion as fallback
mock_text = f"Mock quantum-enhanced response {i+1}: This demonstrates coherent language generation with temperature {temp}."
mock_text = (
f"Mock quantum-enhanced response {i+1}: This demonstrates "
f"coherent language generation with temperature {temp}."
)
trajectory_messages = messages.copy()
trajectory_messages.append({"role": "assistant", "content": mock_text})
to_score.append((tuple([frozenset(msg.items()) for msg in trajectory_messages]), target_text, mock_text))
to_score.append(
(
tuple([frozenset(msg.items()) for msg in trajectory_messages]),
target_text,
mock_text,
)
)
# Score the generated trajectories
scored_data = await self.score(to_score)
@@ -357,15 +401,20 @@ class QuantumHybridEnv(BaseEnv):
quantum_coherence = self.quantum_analyzer.analyze_text(completion_text)
# Calculate quantum variance for additional insight
quantum_variance = abs(quantum_coherence - 0.5) * 2 # Distance from maximum entropy
quantum_variance = (
abs(quantum_coherence - 0.5) * 2
) # Distance from maximum entropy
# Combined score using weighted sum
combined_score = (
self.config.perplexity_weight * perplexity_score +
self.config.quantum_weight * quantum_coherence
self.config.perplexity_weight * perplexity_score
+ self.config.quantum_weight * quantum_coherence
)
print(f" Similarity: {similarity_score:.3f}, Quantum: {quantum_coherence:.3f}, Combined: {combined_score:.3f}")
print(
f" Similarity: {similarity_score:.3f}, Quantum: {quantum_coherence:.3f}, "
f"Combined: {combined_score:.3f}"
)
# Update metrics buffer
self.metrics_buffer["perplexity"].append(perplexity_score)
@@ -387,7 +436,7 @@ class QuantumHybridEnv(BaseEnv):
quantum_scores = []
# Get evaluation examples
eval_examples = self.eval_examples[:min(5, len(self.eval_examples))]
eval_examples = self.eval_examples[: min(5, len(self.eval_examples))]
for example in eval_examples:
try:
@@ -401,14 +450,16 @@
words = text.split()
if len(words) > 8:
split_idx = len(words) // 2
prompt_text = ' '.join(words[:split_idx])
target_text = ' '.join(words[split_idx:])
prompt_text = " ".join(words[:split_idx])
target_text = " ".join(words[split_idx:])
else:
prompt_text = text[:len(text)//2]
target_text = text[len(text)//2:]
prompt_text = text[: len(text) // 2]
target_text = text[len(text) // 2 :]
# Create messages
messages = [{"role": "user", "content": f"Continue this text: {prompt_text}"}]
messages = [
{"role": "user", "content": f"Continue this text: {prompt_text}"}
]
# Generate completion
completion = await self.server.completion(
@@ -437,8 +488,8 @@
# Combined evaluation score
eval_score = (
self.config.perplexity_weight * similarity +
self.config.quantum_weight * quantum_coherence
self.config.perplexity_weight * similarity
+ self.config.quantum_weight * quantum_coherence
)
eval_scores.append(eval_score)
@@ -455,10 +506,16 @@
self.eval_metrics.append(("eval/combined_score", avg_eval_score))
self.eval_metrics.append(("eval/quantum_coherence", avg_quantum_score))
self.eval_metrics.append(("eval/perplexity_weight", self.config.perplexity_weight))
self.eval_metrics.append(("eval/quantum_weight", self.config.quantum_weight))
self.eval_metrics.append(
("eval/perplexity_weight", self.config.perplexity_weight)
)
self.eval_metrics.append(
("eval/quantum_weight", self.config.quantum_weight)
)
print(f"Evaluation complete: avg_score={avg_eval_score:.3f}, avg_quantum={avg_quantum_score:.3f}")
print(
f"Evaluation complete: avg_score={avg_eval_score:.3f}, avg_quantum={avg_quantum_score:.3f}"
)
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
"""Log metrics to Weights & Biases."""


@@ -1,16 +1,11 @@
import asyncio
import itertools
import random
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple
import numpy as np
import pennylane as qml
import torch
import torch.nn.functional as F
import wandb
from datasets import load_dataset
from pydantic import Field
from transformers import AutoModelForCausalLM, AutoTokenizer
from atroposlib.envs.base import (
APIServerConfig,
@@ -18,7 +13,6 @@ from atroposlib.envs.base import (
BaseEnvConfig,
ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
class QuantumHybridConfig(BaseEnvConfig):
@@ -26,27 +20,45 @@ class QuantumHybridConfig(BaseEnvConfig):
# Quantum circuit parameters
n_qubits: int = Field(8, description="Number of qubits in the quantum circuit")
n_layers: int = Field(3, description="Number of quantum layers to use in the circuit")
n_layers: int = Field(
3, description="Number of quantum layers to use in the circuit"
)
# Dataset parameters
dataset_name: str = Field("wikitext", description="Dataset to use for training/evaluation")
dataset_config: str = Field("wikitext-2-raw-v1", description="Dataset configuration")
dataset_name: str = Field(
"wikitext", description="Dataset to use for training/evaluation"
)
dataset_config: str = Field(
"wikitext-2-raw-v1", description="Dataset configuration"
)
sequence_length: int = Field(256, description="Length of sequences to process")
# Base model parameters
base_model_name: str = Field("gpt2", description="Base model for hybrid experiments")
base_model_name: str = Field(
"gpt2", description="Base model for hybrid experiments"
)
# Training parameters
eval_every: int = Field(10, description="Evaluate every N training steps")
perplexity_weight: float = Field(0.7, description="Weight for perplexity in scoring")
quantum_weight: float = Field(0.3, description="Weight for quantum-specific metrics")
perplexity_weight: float = Field(
0.7, description="Weight for perplexity in scoring"
)
quantum_weight: float = Field(
0.3, description="Weight for quantum-specific metrics"
)
# Hybrid model parameters
train_hybrid_model: bool = Field(True, description="Whether to train the hybrid model")
learning_rate: float = Field(1e-4, description="Learning rate for quantum parameters")
train_hybrid_model: bool = Field(
True, description="Whether to train the hybrid model"
)
learning_rate: float = Field(
1e-4, description="Learning rate for quantum parameters"
)
# Comparison parameters
compare_with_classical: bool = Field(True, description="Compare hybrid with classical model")
compare_with_classical: bool = Field(
True, description="Compare hybrid with classical model"
)
class OptimizedQuantumLayer(torch.nn.Module):
@@ -71,10 +83,10 @@ class OptimizedQuantumLayer(torch.nn.Module):
qml.RY(inputs[i], wires=i)
# Apply parameterized quantum layers
for l in range(self.n_layers):
for layer in range(self.n_layers):
# Rotation gates with learnable parameters
for i in range(self.n_qubits):
qml.RY(params[l, i], wires=i)
qml.RY(params[layer, i], wires=i)
# Entanglement between qubits
for i in range(self.n_qubits - 1):
@@ -100,10 +112,12 @@ class OptimizedQuantumLayer(torch.nn.Module):
# Process each item in the batch
for i in range(batch_size):
# Normalize input values to [-1, 1] for quantum embedding
x_norm = torch.tanh(x[i, :self.n_qubits])
x_norm = torch.tanh(x[i, : self.n_qubits])
# Get quantum circuit output
try:
quantum_out = torch.tensor(self.circuit(x_norm, self.params), dtype=torch.float32)
quantum_out = torch.tensor(
self.circuit(x_norm, self.params), dtype=torch.float32
)
results.append(quantum_out)
except Exception as e:
print(f"Quantum circuit error: {e}")
@@ -146,8 +160,6 @@ class OptimizedHybridModel(torch.nn.Module):
hidden_states: [batch_size, hidden_size] tensor
return_classical: if True, return classical logits for comparison
"""
batch_size = hidden_states.shape[0]
# Classical pathway
classical_logits = self.classical_head(hidden_states)
@@ -266,11 +278,13 @@ class QuantumHybridEnv(BaseEnv):
base_model_name=self.config.base_model_name,
n_qubits=self.config.n_qubits,
n_layers=self.config.n_layers,
vocab_size=self.tokenizer.vocab_size
vocab_size=self.tokenizer.vocab_size,
)
# Initialize optimizer for quantum parameters
self.optimizer = torch.optim.Adam(self.hybrid_model.parameters(), lr=self.config.learning_rate)
self.optimizer = torch.optim.Adam(
self.hybrid_model.parameters(), lr=self.config.learning_rate
)
# Sample texts for training
self.sample_texts = [
@@ -283,7 +297,7 @@ class QuantumHybridEnv(BaseEnv):
"Artificial intelligence mimics human cognitive functions.",
"Computer vision enables machines to interpret images.",
"Robotics integrates mechanical and software engineering.",
"Data science extracts insights from large datasets."
"Data science extracts insights from large datasets.",
]
self.iter = 0
@@ -326,12 +340,22 @@
except Exception as e:
print(f"API error: {e}, using fallback responses")
# Fallback responses
completions = type('obj', (object,), {
'choices': [
type('choice', (object,), {'text': f"This is response {i+1} to: {user_content[:50]}..."})()
completions = type(
"obj",
(object,),
{
"choices": [
type(
"choice",
(object,),
{
"text": f"This is response {i+1} to: {user_content[:50]}..."
},
)()
for i in range(self.config.group_size)
]
})()
},
)()
to_score = []
@@ -341,7 +365,9 @@
# Create trajectory
trajectory_messages = messages.copy()
trajectory_messages.append({"role": "assistant", "content": completion_text})
trajectory_messages.append(
{"role": "assistant", "content": completion_text}
)
# Train hybrid model
if self.config.train_hybrid_model:
@@ -385,7 +411,9 @@
# Log metrics
self.metrics_buffer["hybrid_loss"].append(hybrid_loss.item())
self.metrics_buffer["classical_loss"].append(classical_loss.item())
self.metrics_buffer["alpha_value"].append(torch.sigmoid(self.hybrid_model.alpha).item())
self.metrics_buffer["alpha_value"].append(
torch.sigmoid(self.hybrid_model.alpha).item()
)
except Exception as e:
print(f"Training error: {e}")
@@ -417,22 +445,24 @@
text_quality_score = min(len(generated_text) / 100, 1.0)
# Quantum coherence from hybrid model
if hasattr(self, 'hybrid_model') and self.hybrid_model is not None:
if hasattr(self, "hybrid_model") and self.hybrid_model is not None:
try:
hidden_states = torch.randn(1, self.hybrid_model.hidden_size)
with torch.no_grad():
hybrid_logits = self.hybrid_model(hidden_states)
quantum_contribution = 1 - torch.sigmoid(self.hybrid_model.alpha).item()
self.hybrid_model(hidden_states) # Run forward pass
quantum_contribution = (
1 - torch.sigmoid(self.hybrid_model.alpha).item()
)
quantum_score = quantum_contribution + 0.3 * random.random()
except:
except Exception:
quantum_score = 0.5 + 0.3 * random.random()
else:
quantum_score = 0.5 + 0.3 * random.random()
# Combined score
combined_score = (
self.config.perplexity_weight * text_quality_score +
self.config.quantum_weight * quantum_score
self.config.perplexity_weight * text_quality_score
+ self.config.quantum_weight * quantum_score
)
# Update metrics
@@ -455,19 +485,29 @@
for text in self.sample_texts[:5]:
try:
# Test hybrid model performance
if hasattr(self, 'hybrid_model') and self.hybrid_model is not None:
if hasattr(self, "hybrid_model") and self.hybrid_model is not None:
hidden_states = torch.randn(1, self.hybrid_model.hidden_size)
with torch.no_grad():
hybrid_logits = self.hybrid_model(hidden_states)
classical_logits = self.hybrid_model(hidden_states, return_classical=True)
classical_logits = self.hybrid_model(
hidden_states, return_classical=True
)
# Compare distributions
hybrid_entropy = -torch.sum(F.softmax(hybrid_logits, dim=-1) * F.log_softmax(hybrid_logits, dim=-1))
classical_entropy = -torch.sum(F.softmax(classical_logits, dim=-1) * F.log_softmax(classical_logits, dim=-1))
hybrid_entropy = -torch.sum(
F.softmax(hybrid_logits, dim=-1)
* F.log_softmax(hybrid_logits, dim=-1)
)
classical_entropy = -torch.sum(
F.softmax(classical_logits, dim=-1)
* F.log_softmax(classical_logits, dim=-1)
)
# Score based on entropy difference
score = 1.0 - abs(hybrid_entropy - classical_entropy) / max(hybrid_entropy, classical_entropy)
score = 1.0 - abs(hybrid_entropy - classical_entropy) / max(
hybrid_entropy, classical_entropy
)
eval_scores.append(score.item())
except Exception as e:
@@ -479,7 +519,7 @@
self.eval_metrics.append(("eval/hybrid_performance", avg_score))
# Log current alpha value
if hasattr(self, 'hybrid_model') and self.hybrid_model is not None:
if hasattr(self, "hybrid_model") and self.hybrid_model is not None:
alpha_val = torch.sigmoid(self.hybrid_model.alpha).item()
self.eval_metrics.append(("eval/alpha_value", alpha_val))
self.eval_metrics.append(("eval/quantum_weight", 1 - alpha_val))
@@ -506,9 +546,11 @@
self.eval_metrics = []
# Log hybrid model parameters
if hasattr(self, 'hybrid_model') and self.hybrid_model is not None:
if hasattr(self, "hybrid_model") and self.hybrid_model is not None:
wandb_metrics["model/alpha"] = torch.sigmoid(self.hybrid_model.alpha).item()
wandb_metrics["model/quantum_contribution"] = 1 - torch.sigmoid(self.hybrid_model.alpha).item()
wandb_metrics["model/quantum_contribution"] = (
1 - torch.sigmoid(self.hybrid_model.alpha).item()
)
await super().wandb_log(wandb_metrics)