mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
linted, moved to community folder
This commit is contained in:
parent
20c6e9d8d7
commit
b845c635d4
7 changed files with 573 additions and 290 deletions
|
|
@ -918,6 +918,190 @@ python environments/community/poker_holdem/poker_env.py process \
|
|||
|
||||
**Requirements**: datasets, transformers, wandb, atroposlib
|
||||
|
||||
### 18. Quantum-Classical Hybrid Language Model Environment (`quantum_hybrid/`)
|
||||
**Author**: [jeannemtl](https://github.com/jeannemtl)
|
||||
**Purpose**: Train quantum-enhanced language models by combining classical transformers with quantum circuits using PennyLane and PyTorch
|
||||
|
||||
A novel environment that implements a quantum-classical hybrid architecture for next-word prediction, trained on high-quality text generated by Hermes-3-70B. The key innovation is using quantum circuits to enhance traditional neural networks for language modeling tasks, exploring potential quantum advantages in natural language processing.
|
||||
|
||||
**Research Question**: Can quantum circuits provide advantages over purely classical approaches in natural language processing tasks?
|
||||
|
||||
**Architecture Overview**:
|
||||
- **Data Flow**: Input Prompts → Hermes-3-70B (text generation) → Hybrid Model Training → Quantum-Enhanced Predictions
|
||||
- **Hybrid Model Components**:
|
||||
- **Classical Pathway**: Standard transformer-style neural network head
|
||||
- **Quantum Pathway**: Dimensionality reduction (768D → 8D) → Two quantum circuit layers → Quantum-to-vocabulary mapping
|
||||
- **Learnable Mixing**: Parameter α balances classical vs quantum contributions
|
||||
|
||||
**Quantum Circuit Design**:
|
||||
- **8 qubits with 3 parameterized layers**
|
||||
- **RY rotation gates** for classical data encoding
|
||||
- **CNOT gates** creating entanglement patterns
|
||||
- **Pauli-Z measurements** for classical output extraction
|
||||
- **Ring topology** for full qubit connectivity
|
||||
|
||||
**Dual Implementation Approach**:
|
||||
The environment includes two complementary implementations:
|
||||
|
||||
**1. Optimized Hybrid Model (`atropos.py`)**:
|
||||
- **Synthetic Training**: Uses simplified tokenizer and mock hidden states for rapid experimentation
|
||||
- **Quantum Integration**: Full quantum circuit implementation with PennyLane
|
||||
- **Hybrid Architecture**: Learnable mixing between classical and quantum pathways
|
||||
- **Training Loop**: Direct optimization of quantum parameters via gradient descent
|
||||
- **Evaluation**: Entropy-based comparison of hybrid vs classical predictions
|
||||
|
||||
**2. Dataset-Driven Training (`atopos_quant.py`)**:
|
||||
- **Real Data Processing**: Uses WikiText dataset with HuggingFace integration
|
||||
- **Quantum Text Analysis**: Standalone quantum analyzer for text coherence measurement
|
||||
- **Server Integration**: Compatible with Atropos server infrastructure
|
||||
- **Comprehensive Metrics**: Perplexity, quantum coherence, and combined scoring
|
||||
- **Production Ready**: Full tokenization and dataset management
|
||||
|
||||
**Quantum Text Analysis Features**:
|
||||
- **Text Feature Extraction**: Length, word count, character diversity, punctuation patterns
|
||||
- **Quantum Encoding**: Features mapped to quantum states via rotation gates
|
||||
- **Entanglement Patterns**: Complex qubit interactions for linguistic analysis
|
||||
- **Coherence Measurement**: Quantum variance as text quality indicator
|
||||
- **Fallback Mechanisms**: Graceful degradation when quantum circuits fail
|
||||
|
||||
**Training Strategy - Quantum-Enhanced Knowledge Distillation**:
|
||||
1. **Teacher Model**: Hermes-3-70B generates diverse, high-quality responses
|
||||
2. **Student Model**: Hybrid quantum-classical model learns next-word prediction
|
||||
3. **Comparison**: Direct evaluation of quantum vs classical pathways within same model
|
||||
4. **Optimization**: Both classical and quantum parameters trained via gradient descent
|
||||
|
||||
**Key Metrics & Evaluation**:
|
||||
|
||||
**Training Metrics**:
|
||||
- `train/hybrid_loss`: Combined quantum-classical model loss
|
||||
- `train/classical_loss`: Baseline classical-only model loss
|
||||
- `train/quantum_loss`: Quantum-specific loss component
|
||||
- `train/alpha_value`: Mixing parameter (0 = full quantum, 1 = full classical)
|
||||
|
||||
**Evaluation Metrics**:
|
||||
- `eval/hybrid_performance`: Entropy-based comparison of hybrid vs classical predictions
|
||||
- `eval/quantum_weight`: Current quantum contribution (1 - α)
|
||||
- `train/quantum_coherence`: Measure of quantum circuit effectiveness
|
||||
|
||||
**Model Metrics**:
|
||||
- `model/alpha`: Real-time mixing parameter
|
||||
- `model/quantum_contribution`: Percentage of quantum influence
|
||||
|
||||
**Interpretation Guide**:
|
||||
- **Decreasing hybrid_loss**: Model improving at next-word prediction
|
||||
- **Stable alpha_value**: Balanced classical-quantum integration
|
||||
- **High quantum_coherence**: Quantum circuits contributing meaningfully
|
||||
- **hybrid_performance > 0.5**: Quantum enhancement provides benefits
|
||||
|
||||
**Technical Implementation Details**:
|
||||
|
||||
**Quantum Circuit Architecture**:
|
||||
```python
|
||||
# Data encoding
|
||||
qml.RY(classical_data, wires=qubit)
|
||||
|
||||
# Parameterized layers
|
||||
for layer in range(n_layers):
|
||||
for qubit in range(n_qubits):
|
||||
qml.RY(learnable_params[layer, qubit], wires=qubit)
|
||||
|
||||
# Entanglement pattern
|
||||
for i in range(n_qubits - 1):
|
||||
qml.CNOT(wires=[i, i + 1])
|
||||
qml.CNOT(wires=[n_qubits - 1, 0]) # Ring topology
|
||||
|
||||
# Measurement
|
||||
[qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]
|
||||
```
|
||||
|
||||
**Training Process**:
|
||||
1. **Forward Pass**: Hidden states → quantum circuits → predictions
|
||||
2. **Loss Calculation**: Cross-entropy on next-word prediction
|
||||
3. **Backpropagation**: Gradients through quantum circuits via parameter-shift rule
|
||||
4. **Optimization**: Adam optimizer updates both classical and quantum parameters
|
||||
|
||||
**Novel Contributions**:
|
||||
- **First quantum-enhanced Atropos environment**
|
||||
- **Hybrid architecture balancing quantum and classical processing**
|
||||
- **Knowledge distillation from large classical models to small quantum models**
|
||||
- **Quantum-aware evaluation metrics for NLP tasks**
|
||||
|
||||
**Current Limitations**:
|
||||
- **Simulated Quantum**: Uses classical simulation (no quantum hardware)
|
||||
- **Synthetic Features**: Uses random hidden states (not real text embeddings in optimized version)
|
||||
- **Scale**: Limited to 8 qubits due to exponential simulation cost
|
||||
- **Evaluation**: Simple entropy comparison (more sophisticated metrics possible)
|
||||
|
||||
**Potential Applications**:
|
||||
- **Quantum NLP Research**: Differentiable quantum circuits for language tasks
|
||||
- **Hybrid Model Architectures**: Resource-constrained environments with quantum enhancement
|
||||
- **Novel Optimization**: Combining classical and quantum approaches
|
||||
- **Benchmark Creation**: Quantum machine learning evaluation in language tasks
|
||||
|
||||
**Future Research Directions**:
|
||||
|
||||
**Immediate Improvements**:
|
||||
- **Real Text Processing**: Replace synthetic hidden states with actual transformer embeddings
|
||||
- **Advanced Quantum Circuits**: Implement quantum attention mechanisms
|
||||
- **Scaling Studies**: Investigate qubit count vs performance relationships
|
||||
|
||||
**Long-term Goals**:
|
||||
- **Quantum Hardware**: Deploy on IBM Quantum, IonQ, or other quantum computers
|
||||
- **Larger Models**: Scale to 100+ qubit systems when available
|
||||
- **Quantum Advantage**: Identify specific NLP tasks where quantum provides provable benefits
|
||||
- **Production Systems**: Develop practical quantum-enhanced language models
|
||||
|
||||
**Configuration Options**:
|
||||
- **Quantum Parameters**: Configurable qubit count (default: 8) and layer depth (default: 3)
|
||||
- **Training Settings**: Learning rate, batch size, total steps, evaluation frequency
|
||||
- **Model Architecture**: Base model selection, vocabulary size, hidden dimensions
|
||||
- **Hybrid Weighting**: Adjustable balance between classical and quantum contributions
|
||||
- **Dataset Selection**: WikiText variants or custom text datasets
|
||||
|
||||
**Setup Requirements**:
|
||||
1. **PennyLane**: Quantum computing framework
|
||||
2. **PyTorch**: Deep learning and automatic differentiation
|
||||
3. **Transformers**: Tokenization and model utilities
|
||||
4. **Datasets**: HuggingFace dataset loading
|
||||
5. **NumPy**: Numerical computations
|
||||
6. **WandB**: Experiment tracking and visualization
|
||||
|
||||
**Installation & Usage**:
|
||||
```bash
|
||||
# Install quantum dependencies
|
||||
pip install pennylane torch transformers datasets numpy wandb
|
||||
|
||||
# Run optimized hybrid training
|
||||
python environments/community/quantum_hybrid/atropos.py process \
|
||||
--env.n_qubits 8 \
|
||||
--env.n_layers 3 \
|
||||
--env.total_steps 50 \
|
||||
--env.quantum_weight 0.3
|
||||
|
||||
# Run dataset-driven training
|
||||
python environments/community/quantum_hybrid/atopos_quant.py process \
|
||||
--env.dataset_name wikitext \
|
||||
--env.dataset_config wikitext-2-raw-v1 \
|
||||
--env.n_qubits 8
|
||||
```
|
||||
|
||||
**Live Experiment Tracking**: Monitor training progress and quantum metrics on the WandB dashboard, with real-time visualization of the quantum-classical balance and performance metrics.
|
||||
|
||||
**Research Impact**: This environment represents cutting-edge research in quantum machine learning for NLP. While quantum advantages are still under investigation, the framework provides a foundation for future work on quantum-enhanced language processing.
|
||||
|
||||
**Repository Structure**:
|
||||
```
|
||||
environments/community/quantum_hybrid/
|
||||
├── atropos.py # Optimized hybrid model implementation
|
||||
├── atopos_quant.py # Dataset-driven quantum training
|
||||
├── requirements.txt # Python dependencies
|
||||
├── README.md # Detailed documentation
|
||||
├── quantum_hybrid_artifacts.tar.gz # Training artifacts
|
||||
└── quantum_latest_artifacts.tar.gz # Latest training data
|
||||
```
|
||||
|
||||
**Requirements**: pennylane, torch, transformers, datasets, numpy, pydantic, atroposlib
|
||||
|
||||
---
|
||||
|
||||
## Support
|
||||
|
|
|
|||
|
|
@ -1,15 +1,11 @@
|
|||
import asyncio
|
||||
import itertools
|
||||
import random
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pennylane as qml
|
||||
import torch
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
from pydantic import Field
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
|
|
@ -22,54 +18,66 @@ from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
|||
|
||||
class QuantumHybridConfig(BaseEnvConfig):
|
||||
"""Configuration for the Quantum-Classical Hybrid Environment."""
|
||||
|
||||
|
||||
# Quantum circuit parameters
|
||||
n_qubits: int = Field(8, description="Number of qubits in the quantum circuit")
|
||||
n_layers: int = Field(3, description="Number of quantum layers to use in the circuit")
|
||||
|
||||
n_layers: int = Field(
|
||||
3, description="Number of quantum layers to use in the circuit"
|
||||
)
|
||||
|
||||
# Dataset parameters
|
||||
dataset_name: str = Field("wikitext", description="Dataset to use for training/evaluation")
|
||||
dataset_config: str = Field("wikitext-2-raw-v1", description="Dataset configuration")
|
||||
dataset_name: str = Field(
|
||||
"wikitext", description="Dataset to use for training/evaluation"
|
||||
)
|
||||
dataset_config: str = Field(
|
||||
"wikitext-2-raw-v1", description="Dataset configuration"
|
||||
)
|
||||
sequence_length: int = Field(256, description="Length of sequences to process")
|
||||
|
||||
|
||||
# Base model parameters
|
||||
base_model_name: str = Field("gpt2", description="Base model for hybrid experiments")
|
||||
|
||||
base_model_name: str = Field(
|
||||
"gpt2", description="Base model for hybrid experiments"
|
||||
)
|
||||
|
||||
# Training parameters
|
||||
perplexity_weight: float = Field(0.7, description="Weight for perplexity in scoring")
|
||||
quantum_weight: float = Field(0.3, description="Weight for quantum-specific metrics")
|
||||
perplexity_weight: float = Field(
|
||||
0.7, description="Weight for perplexity in scoring"
|
||||
)
|
||||
quantum_weight: float = Field(
|
||||
0.3, description="Weight for quantum-specific metrics"
|
||||
)
|
||||
|
||||
|
||||
class QuantumTextAnalyzer:
|
||||
"""Standalone quantum analyzer for text coherence measurement."""
|
||||
|
||||
|
||||
def __init__(self, n_qubits=6):
|
||||
self.n_qubits = n_qubits
|
||||
self.dev = qml.device("default.qubit", wires=n_qubits)
|
||||
|
||||
|
||||
@qml.qnode(self.dev)
|
||||
def text_analysis_circuit(text_features):
|
||||
# Embed text features as quantum states
|
||||
for i in range(self.n_qubits):
|
||||
qml.RY(text_features[i], wires=i)
|
||||
|
||||
|
||||
# Create entanglement patterns
|
||||
for i in range(self.n_qubits - 1):
|
||||
qml.CNOT(wires=[i, i+1])
|
||||
|
||||
qml.CNOT(wires=[i, i + 1])
|
||||
|
||||
# Ring closure
|
||||
if self.n_qubits > 1:
|
||||
qml.CNOT(wires=[self.n_qubits-1, 0])
|
||||
|
||||
qml.CNOT(wires=[self.n_qubits - 1, 0])
|
||||
|
||||
# Additional entanglement for complex analysis
|
||||
for i in range(0, self.n_qubits-2, 2):
|
||||
qml.CNOT(wires=[i, i+2])
|
||||
|
||||
for i in range(0, self.n_qubits - 2, 2):
|
||||
qml.CNOT(wires=[i, i + 2])
|
||||
|
||||
# Measure coherence
|
||||
return [qml.expval(qml.PauliZ(i)) for i in range(self.n_qubits)]
|
||||
|
||||
|
||||
self.circuit = text_analysis_circuit
|
||||
|
||||
|
||||
def analyze_text(self, text: str) -> float:
|
||||
"""Analyze text and return quantum coherence score."""
|
||||
try:
|
||||
|
|
@ -77,12 +85,18 @@ class QuantumTextAnalyzer:
|
|||
text_len = min(len(text), 200) / 200.0
|
||||
word_count = len(text.split()) / 100.0 if text.split() else 0.0
|
||||
char_diversity = len(set(text.lower())) / 26.0 if text else 0.0
|
||||
avg_word_len = np.mean([len(word) for word in text.split()]) / 15.0 if text.split() else 0.0
|
||||
|
||||
avg_word_len = (
|
||||
np.mean([len(word) for word in text.split()]) / 15.0
|
||||
if text.split()
|
||||
else 0.0
|
||||
)
|
||||
|
||||
# Additional features
|
||||
punctuation_ratio = sum(1 for c in text if c in '.,!?;:') / max(len(text), 1)
|
||||
punctuation_ratio = sum(1 for c in text if c in ".,!?;:") / max(
|
||||
len(text), 1
|
||||
)
|
||||
uppercase_ratio = sum(1 for c in text if c.isupper()) / max(len(text), 1)
|
||||
|
||||
|
||||
# Encode as quantum features
|
||||
features = [
|
||||
text_len * np.pi,
|
||||
|
|
@ -90,16 +104,16 @@ class QuantumTextAnalyzer:
|
|||
char_diversity * np.pi,
|
||||
avg_word_len * np.pi,
|
||||
punctuation_ratio * np.pi,
|
||||
uppercase_ratio * np.pi
|
||||
uppercase_ratio * np.pi,
|
||||
]
|
||||
|
||||
|
||||
# Run quantum analysis
|
||||
measurements = self.circuit(features)
|
||||
|
||||
|
||||
# Calculate coherence as normalized measurement variance
|
||||
coherence = np.var(measurements) / (np.var(measurements) + 0.1)
|
||||
return float(np.clip(coherence, 0.0, 1.0))
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Quantum text analysis error: {e}")
|
||||
# Fallback to simple heuristic
|
||||
|
|
@ -108,7 +122,7 @@ class QuantumTextAnalyzer:
|
|||
|
||||
class QuantumHybridEnv(BaseEnv):
|
||||
"""Environment for training and evaluating quantum-classical hybrid models."""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: QuantumHybridConfig,
|
||||
|
|
@ -124,16 +138,16 @@ class QuantumHybridEnv(BaseEnv):
|
|||
"combined_score": [],
|
||||
"quantum_variance": [],
|
||||
}
|
||||
|
||||
|
||||
# Initialize eval_metrics list (required for BaseEnv)
|
||||
self.eval_metrics = []
|
||||
|
||||
|
||||
# Initialize quantum text analyzer
|
||||
self.quantum_analyzer = QuantumTextAnalyzer(n_qubits=config.n_qubits)
|
||||
|
||||
|
||||
# Track training iteration
|
||||
self.iter = 0
|
||||
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[QuantumHybridConfig, List[APIServerConfig]]:
|
||||
"""Initialize default configuration."""
|
||||
|
|
@ -156,7 +170,7 @@ class QuantumHybridEnv(BaseEnv):
|
|||
perplexity_weight=0.7,
|
||||
quantum_weight=0.3,
|
||||
)
|
||||
|
||||
|
||||
# The server config here tells Atropos to route to your vLLM server
|
||||
# This should match whatever model name the Atropos server expects
|
||||
server_configs = [
|
||||
|
|
@ -173,105 +187,118 @@ class QuantumHybridEnv(BaseEnv):
|
|||
rolling_buffer_length=100,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
return config, server_configs
|
||||
|
||||
|
||||
async def setup(self):
|
||||
"""Set up the environment, including loading datasets."""
|
||||
print(f"Setting up QuantumHybridEnv with quantum parameters: {self.config.n_qubits} qubits, {self.config.n_layers} layers")
|
||||
|
||||
print(
|
||||
f"Setting up QuantumHybridEnv with quantum parameters: "
|
||||
f"{self.config.n_qubits} qubits, {self.config.n_layers} layers"
|
||||
)
|
||||
|
||||
# Initialize tokenizer
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_name)
|
||||
if self.tokenizer.pad_token is None:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
|
||||
# Load dataset
|
||||
try:
|
||||
dataset = load_dataset(
|
||||
self.config.dataset_name,
|
||||
self.config.dataset_config,
|
||||
split="train",
|
||||
streaming=True
|
||||
self.config.dataset_name,
|
||||
self.config.dataset_config,
|
||||
split="train",
|
||||
streaming=True,
|
||||
)
|
||||
self.dataset = dataset.take(10000) # Take first 10k examples
|
||||
|
||||
|
||||
eval_dataset = load_dataset(
|
||||
self.config.dataset_name,
|
||||
self.config.dataset_config,
|
||||
split="validation",
|
||||
streaming=True
|
||||
streaming=True,
|
||||
)
|
||||
self.eval_dataset = eval_dataset.take(100) # Take first 100 for eval
|
||||
|
||||
print(f"Loaded dataset: {self.config.dataset_name}/{self.config.dataset_config}")
|
||||
|
||||
|
||||
print(
|
||||
f"Loaded dataset: {self.config.dataset_name}/{self.config.dataset_config}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to load dataset: {e}")
|
||||
# Create a mock dataset for testing
|
||||
self.dataset = [{"text": f"Sample text {i} for quantum analysis. This text demonstrates complex linguistic patterns."} for i in range(1000)]
|
||||
self.eval_dataset = [{"text": f"Eval text {i} with various complexity levels."} for i in range(100)]
|
||||
self.dataset = [
|
||||
{
|
||||
"text": f"Sample text {i} for quantum analysis. This text demonstrates complex linguistic patterns."
|
||||
}
|
||||
for i in range(1000)
|
||||
]
|
||||
self.eval_dataset = [
|
||||
{"text": f"Eval text {i} with various complexity levels."}
|
||||
for i in range(100)
|
||||
]
|
||||
print("Using mock dataset")
|
||||
|
||||
|
||||
# Convert to lists for easier iteration
|
||||
if hasattr(self.dataset, '__iter__'):
|
||||
if hasattr(self.dataset, "__iter__"):
|
||||
self.train_examples = list(self.dataset)
|
||||
else:
|
||||
self.train_examples = self.dataset
|
||||
|
||||
if hasattr(self.eval_dataset, '__iter__'):
|
||||
|
||||
if hasattr(self.eval_dataset, "__iter__"):
|
||||
self.eval_examples = list(self.eval_dataset)
|
||||
else:
|
||||
self.eval_examples = self.eval_dataset
|
||||
|
||||
|
||||
print(f"Loaded {len(self.train_examples)} training examples")
|
||||
print(f"Loaded {len(self.eval_examples)} evaluation examples")
|
||||
|
||||
|
||||
async def get_next_item(self):
|
||||
"""Get the next training item from the dataset."""
|
||||
# Get next instance from the dataset
|
||||
if self.iter >= len(self.train_examples):
|
||||
self.iter = 0 # Reset to beginning
|
||||
|
||||
|
||||
data_point = self.train_examples[self.iter]
|
||||
self.iter += 1
|
||||
|
||||
|
||||
# Process text data
|
||||
if isinstance(data_point, dict) and "text" in data_point:
|
||||
text = data_point["text"]
|
||||
else:
|
||||
text = str(data_point)
|
||||
|
||||
|
||||
# Truncate text to reasonable length
|
||||
text = text[:self.config.sequence_length * 4] # Allow for some context
|
||||
|
||||
text = text[: self.config.sequence_length * 4] # Allow for some context
|
||||
|
||||
# Create a simple continuation task
|
||||
words = text.split()
|
||||
if len(words) > 10:
|
||||
# Split at a random point for continuation
|
||||
split_idx = random.randint(5, min(len(words) - 5, 20))
|
||||
prompt_text = ' '.join(words[:split_idx])
|
||||
target_text = ' '.join(words[split_idx:split_idx + 20])
|
||||
prompt_text = " ".join(words[:split_idx])
|
||||
target_text = " ".join(words[split_idx : split_idx + 20])
|
||||
else:
|
||||
prompt_text = text[:len(text)//2]
|
||||
target_text = text[len(text)//2:]
|
||||
|
||||
prompt_text = text[: len(text) // 2]
|
||||
target_text = text[len(text) // 2 :]
|
||||
|
||||
# Convert to messages format for Atropos
|
||||
user_msg = {"role": "user", "content": f"Continue this text: {prompt_text}"}
|
||||
|
||||
|
||||
# Return as (messages, target) tuple
|
||||
return ([user_msg], target_text)
|
||||
|
||||
|
||||
async def collect_trajectories(self, item: Tuple) -> Tuple[ScoredDataGroup, List]:
|
||||
"""Generate and collect model responses for scoring."""
|
||||
messages, target_text = item
|
||||
|
||||
|
||||
print(f"Generating completions for: {messages[0]['content'][:50]}...")
|
||||
|
||||
|
||||
to_score = []
|
||||
|
||||
|
||||
# Generate multiple completions with different temperatures
|
||||
temperatures = [0.6, 0.8, 1.0, 1.2][:self.config.group_size]
|
||||
|
||||
temperatures = [0.6, 0.8, 1.0, 1.2][: self.config.group_size]
|
||||
|
||||
for i, temp in enumerate(temperatures):
|
||||
try:
|
||||
# Use the Atropos server to generate completions
|
||||
|
|
@ -281,50 +308,67 @@ class QuantumHybridEnv(BaseEnv):
|
|||
max_tokens=min(100, self.config.max_token_length // 4),
|
||||
temperature=temp,
|
||||
)
|
||||
|
||||
|
||||
completion_text = completion.choices[0].text.strip()
|
||||
print(f"Generated completion {i+1} (T={temp}): {completion_text[:50]}...")
|
||||
|
||||
print(
|
||||
f"Generated completion {i+1} (T={temp}): {completion_text[:50]}..."
|
||||
)
|
||||
|
||||
# Create trajectory messages
|
||||
trajectory_messages = messages.copy()
|
||||
trajectory_messages.append(
|
||||
{"role": "assistant", "content": completion_text}
|
||||
)
|
||||
|
||||
|
||||
# Add to scoring queue
|
||||
to_score.append((tuple([frozenset(msg.items()) for msg in trajectory_messages]), target_text, completion_text))
|
||||
|
||||
to_score.append(
|
||||
(
|
||||
tuple([frozenset(msg.items()) for msg in trajectory_messages]),
|
||||
target_text,
|
||||
completion_text,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error generating completion {i+1} with temperature {temp}: {e}")
|
||||
# Create a mock completion as fallback
|
||||
mock_text = f"Mock quantum-enhanced response {i+1}: This demonstrates coherent language generation with temperature {temp}."
|
||||
mock_text = (
|
||||
f"Mock quantum-enhanced response {i+1}: This demonstrates "
|
||||
f"coherent language generation with temperature {temp}."
|
||||
)
|
||||
trajectory_messages = messages.copy()
|
||||
trajectory_messages.append({"role": "assistant", "content": mock_text})
|
||||
to_score.append((tuple([frozenset(msg.items()) for msg in trajectory_messages]), target_text, mock_text))
|
||||
|
||||
to_score.append(
|
||||
(
|
||||
tuple([frozenset(msg.items()) for msg in trajectory_messages]),
|
||||
target_text,
|
||||
mock_text,
|
||||
)
|
||||
)
|
||||
|
||||
# Score the generated trajectories
|
||||
scored_data = await self.score(to_score)
|
||||
|
||||
|
||||
return scored_data, []
|
||||
|
||||
|
||||
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
|
||||
"""Score model outputs based on perplexity and quantum metrics."""
|
||||
if not rollout_group_data:
|
||||
return None
|
||||
|
||||
|
||||
scores = ScoredDataGroup()
|
||||
scores["tokens"] = []
|
||||
scores["masks"] = []
|
||||
scores["scores"] = []
|
||||
|
||||
|
||||
print(f"Scoring {len(rollout_group_data)} completions...")
|
||||
|
||||
|
||||
for item in rollout_group_data:
|
||||
frozen_messages, target_text, completion_text = item
|
||||
|
||||
|
||||
# Convert frozen messages back to regular format
|
||||
messages = [dict(frozen_msg) for frozen_msg in frozen_messages]
|
||||
|
||||
|
||||
# Convert messages to tokens
|
||||
try:
|
||||
tokenized = tokenize_for_trainer(self.tokenizer, messages)
|
||||
|
|
@ -333,11 +377,11 @@ class QuantumHybridEnv(BaseEnv):
|
|||
except Exception as e:
|
||||
print(f"Tokenization error: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# Calculate text similarity score (proxy for perplexity)
|
||||
completion_words = set(completion_text.lower().split())
|
||||
target_words = set(target_text.lower().split())
|
||||
|
||||
|
||||
# Jaccard similarity
|
||||
if completion_words and target_words:
|
||||
intersection = len(completion_words & target_words)
|
||||
|
|
@ -345,50 +389,55 @@ class QuantumHybridEnv(BaseEnv):
|
|||
similarity_score = intersection / union if union > 0 else 0.0
|
||||
else:
|
||||
similarity_score = 0.0
|
||||
|
||||
|
||||
# Length penalty/bonus
|
||||
len_ratio = len(completion_text) / max(len(target_text), 1)
|
||||
len_score = 1.0 - abs(1.0 - len_ratio) if len_ratio > 0 else 0.0
|
||||
|
||||
|
||||
# Combined perplexity-like score
|
||||
perplexity_score = 0.7 * similarity_score + 0.3 * len_score
|
||||
|
||||
|
||||
# Quantum coherence analysis
|
||||
quantum_coherence = self.quantum_analyzer.analyze_text(completion_text)
|
||||
|
||||
|
||||
# Calculate quantum variance for additional insight
|
||||
quantum_variance = abs(quantum_coherence - 0.5) * 2 # Distance from maximum entropy
|
||||
|
||||
quantum_variance = (
|
||||
abs(quantum_coherence - 0.5) * 2
|
||||
) # Distance from maximum entropy
|
||||
|
||||
# Combined score using weighted sum
|
||||
combined_score = (
|
||||
self.config.perplexity_weight * perplexity_score +
|
||||
self.config.quantum_weight * quantum_coherence
|
||||
self.config.perplexity_weight * perplexity_score
|
||||
+ self.config.quantum_weight * quantum_coherence
|
||||
)
|
||||
|
||||
print(f" Similarity: {similarity_score:.3f}, Quantum: {quantum_coherence:.3f}, Combined: {combined_score:.3f}")
|
||||
|
||||
|
||||
print(
|
||||
f" Similarity: {similarity_score:.3f}, Quantum: {quantum_coherence:.3f}, "
|
||||
f"Combined: {combined_score:.3f}"
|
||||
)
|
||||
|
||||
# Update metrics buffer
|
||||
self.metrics_buffer["perplexity"].append(perplexity_score)
|
||||
self.metrics_buffer["quantum_coherence"].append(quantum_coherence)
|
||||
self.metrics_buffer["combined_score"].append(combined_score)
|
||||
self.metrics_buffer["quantum_variance"].append(quantum_variance)
|
||||
|
||||
|
||||
# Store data for training (scale to [-1, 1] range)
|
||||
scores["tokens"].append(tokens)
|
||||
scores["masks"].append(masks)
|
||||
scores["scores"].append(2 * combined_score - 1)
|
||||
|
||||
|
||||
return scores if scores["scores"] else None
|
||||
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""Evaluate the model on a test set."""
|
||||
print("Running quantum-enhanced evaluation...")
|
||||
eval_scores = []
|
||||
quantum_scores = []
|
||||
|
||||
|
||||
# Get evaluation examples
|
||||
eval_examples = self.eval_examples[:min(5, len(self.eval_examples))]
|
||||
|
||||
eval_examples = self.eval_examples[: min(5, len(self.eval_examples))]
|
||||
|
||||
for example in eval_examples:
|
||||
try:
|
||||
# Process text
|
||||
|
|
@ -396,20 +445,22 @@ class QuantumHybridEnv(BaseEnv):
|
|||
text = example["text"]
|
||||
else:
|
||||
text = str(example)
|
||||
|
||||
|
||||
# Create continuation task
|
||||
words = text.split()
|
||||
if len(words) > 8:
|
||||
split_idx = len(words) // 2
|
||||
prompt_text = ' '.join(words[:split_idx])
|
||||
target_text = ' '.join(words[split_idx:])
|
||||
prompt_text = " ".join(words[:split_idx])
|
||||
target_text = " ".join(words[split_idx:])
|
||||
else:
|
||||
prompt_text = text[:len(text)//2]
|
||||
target_text = text[len(text)//2:]
|
||||
|
||||
prompt_text = text[: len(text) // 2]
|
||||
target_text = text[len(text) // 2 :]
|
||||
|
||||
# Create messages
|
||||
messages = [{"role": "user", "content": f"Continue this text: {prompt_text}"}]
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": f"Continue this text: {prompt_text}"}
|
||||
]
|
||||
|
||||
# Generate completion
|
||||
completion = await self.server.completion(
|
||||
prompt=self.tokenizer.apply_chat_template(messages, tokenize=False),
|
||||
|
|
@ -418,53 +469,59 @@ class QuantumHybridEnv(BaseEnv):
|
|||
temperature=0.7,
|
||||
split="eval",
|
||||
)
|
||||
|
||||
|
||||
completion_text = completion.choices[0].text.strip()
|
||||
|
||||
|
||||
# Calculate metrics
|
||||
completion_words = set(completion_text.lower().split())
|
||||
target_words = set(target_text.lower().split())
|
||||
|
||||
|
||||
if completion_words and target_words:
|
||||
intersection = len(completion_words & target_words)
|
||||
union = len(completion_words | target_words)
|
||||
similarity = intersection / union if union > 0 else 0.0
|
||||
else:
|
||||
similarity = 0.0
|
||||
|
||||
|
||||
# Quantum analysis
|
||||
quantum_coherence = self.quantum_analyzer.analyze_text(completion_text)
|
||||
|
||||
|
||||
# Combined evaluation score
|
||||
eval_score = (
|
||||
self.config.perplexity_weight * similarity +
|
||||
self.config.quantum_weight * quantum_coherence
|
||||
self.config.perplexity_weight * similarity
|
||||
+ self.config.quantum_weight * quantum_coherence
|
||||
)
|
||||
|
||||
|
||||
eval_scores.append(eval_score)
|
||||
quantum_scores.append(quantum_coherence)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Evaluation error: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# Log evaluation metrics
|
||||
if eval_scores:
|
||||
avg_eval_score = sum(eval_scores) / len(eval_scores)
|
||||
avg_quantum_score = sum(quantum_scores) / len(quantum_scores)
|
||||
|
||||
|
||||
self.eval_metrics.append(("eval/combined_score", avg_eval_score))
|
||||
self.eval_metrics.append(("eval/quantum_coherence", avg_quantum_score))
|
||||
self.eval_metrics.append(("eval/perplexity_weight", self.config.perplexity_weight))
|
||||
self.eval_metrics.append(("eval/quantum_weight", self.config.quantum_weight))
|
||||
|
||||
print(f"Evaluation complete: avg_score={avg_eval_score:.3f}, avg_quantum={avg_quantum_score:.3f}")
|
||||
|
||||
self.eval_metrics.append(
|
||||
("eval/perplexity_weight", self.config.perplexity_weight)
|
||||
)
|
||||
self.eval_metrics.append(
|
||||
("eval/quantum_weight", self.config.quantum_weight)
|
||||
)
|
||||
|
||||
print(
|
||||
f"Evaluation complete: avg_score={avg_eval_score:.3f}, avg_quantum={avg_quantum_score:.3f}"
|
||||
)
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log metrics to Weights & Biases."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
|
||||
# Calculate and add metrics from buffer
|
||||
for metric_name, values in self.metrics_buffer.items():
|
||||
if values:
|
||||
|
|
@ -472,22 +529,22 @@ class QuantumHybridEnv(BaseEnv):
|
|||
wandb_metrics[f"train/{metric_name}_max"] = max(values)
|
||||
wandb_metrics[f"train/{metric_name}_min"] = min(values)
|
||||
wandb_metrics[f"train/{metric_name}_std"] = np.std(values)
|
||||
|
||||
|
||||
# Add quantum-specific metrics
|
||||
wandb_metrics["quantum/n_qubits"] = self.config.n_qubits
|
||||
wandb_metrics["quantum/n_layers"] = self.config.n_layers
|
||||
wandb_metrics["quantum/weight"] = self.config.quantum_weight
|
||||
wandb_metrics["train/perplexity_weight"] = self.config.perplexity_weight
|
||||
|
||||
|
||||
# Clear the buffer
|
||||
self.metrics_buffer = {key: [] for key in self.metrics_buffer}
|
||||
|
||||
|
||||
# Add evaluation metrics
|
||||
for name, value in self.eval_metrics:
|
||||
wandb_metrics[name] = value
|
||||
|
||||
|
||||
self.eval_metrics = []
|
||||
|
||||
|
||||
# Log to wandb using the parent method
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
|
@ -1,16 +1,11 @@
|
|||
import asyncio
|
||||
import itertools
|
||||
import random
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pennylane as qml
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import wandb
|
||||
from datasets import load_dataset
|
||||
from pydantic import Field
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
|
|
@ -18,160 +13,177 @@ from atroposlib.envs.base import (
|
|||
BaseEnvConfig,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
|
||||
class QuantumHybridConfig(BaseEnvConfig):
|
||||
"""Configuration for the Quantum-Classical Hybrid Environment."""
|
||||
|
||||
|
||||
# Quantum circuit parameters
|
||||
n_qubits: int = Field(8, description="Number of qubits in the quantum circuit")
|
||||
n_layers: int = Field(3, description="Number of quantum layers to use in the circuit")
|
||||
|
||||
n_layers: int = Field(
|
||||
3, description="Number of quantum layers to use in the circuit"
|
||||
)
|
||||
|
||||
# Dataset parameters
|
||||
dataset_name: str = Field("wikitext", description="Dataset to use for training/evaluation")
|
||||
dataset_config: str = Field("wikitext-2-raw-v1", description="Dataset configuration")
|
||||
dataset_name: str = Field(
|
||||
"wikitext", description="Dataset to use for training/evaluation"
|
||||
)
|
||||
dataset_config: str = Field(
|
||||
"wikitext-2-raw-v1", description="Dataset configuration"
|
||||
)
|
||||
sequence_length: int = Field(256, description="Length of sequences to process")
|
||||
|
||||
|
||||
# Base model parameters
|
||||
base_model_name: str = Field("gpt2", description="Base model for hybrid experiments")
|
||||
|
||||
base_model_name: str = Field(
|
||||
"gpt2", description="Base model for hybrid experiments"
|
||||
)
|
||||
|
||||
# Training parameters
|
||||
eval_every: int = Field(10, description="Evaluate every N training steps")
|
||||
perplexity_weight: float = Field(0.7, description="Weight for perplexity in scoring")
|
||||
quantum_weight: float = Field(0.3, description="Weight for quantum-specific metrics")
|
||||
|
||||
perplexity_weight: float = Field(
|
||||
0.7, description="Weight for perplexity in scoring"
|
||||
)
|
||||
quantum_weight: float = Field(
|
||||
0.3, description="Weight for quantum-specific metrics"
|
||||
)
|
||||
|
||||
# Hybrid model parameters
|
||||
train_hybrid_model: bool = Field(True, description="Whether to train the hybrid model")
|
||||
learning_rate: float = Field(1e-4, description="Learning rate for quantum parameters")
|
||||
|
||||
train_hybrid_model: bool = Field(
|
||||
True, description="Whether to train the hybrid model"
|
||||
)
|
||||
learning_rate: float = Field(
|
||||
1e-4, description="Learning rate for quantum parameters"
|
||||
)
|
||||
|
||||
# Comparison parameters
|
||||
compare_with_classical: bool = Field(True, description="Compare hybrid with classical model")
|
||||
compare_with_classical: bool = Field(
|
||||
True, description="Compare hybrid with classical model"
|
||||
)
|
||||
|
||||
|
||||
class OptimizedQuantumLayer(torch.nn.Module):
|
||||
"""Quantum circuit layer implementation using PennyLane."""
|
||||
|
||||
|
||||
def __init__(self, n_qubits=8, n_layers=3):
|
||||
super().__init__()
|
||||
self.n_qubits = n_qubits
|
||||
self.n_layers = n_layers
|
||||
|
||||
|
||||
# Create a quantum device with the specified number of qubits
|
||||
self.dev = qml.device("default.qubit", wires=n_qubits)
|
||||
|
||||
|
||||
# Initialize quantum circuit parameters
|
||||
self.params = torch.nn.Parameter(torch.randn(n_layers, n_qubits) * 0.1)
|
||||
|
||||
|
||||
# Define the quantum circuit
|
||||
@qml.qnode(self.dev, interface="torch")
|
||||
def circuit(inputs, params):
|
||||
# Embed classical data as quantum states
|
||||
for i in range(self.n_qubits):
|
||||
qml.RY(inputs[i], wires=i)
|
||||
|
||||
|
||||
# Apply parameterized quantum layers
|
||||
for l in range(self.n_layers):
|
||||
for layer in range(self.n_layers):
|
||||
# Rotation gates with learnable parameters
|
||||
for i in range(self.n_qubits):
|
||||
qml.RY(params[l, i], wires=i)
|
||||
|
||||
qml.RY(params[layer, i], wires=i)
|
||||
|
||||
# Entanglement between qubits
|
||||
for i in range(self.n_qubits - 1):
|
||||
qml.CNOT(wires=[i, i + 1])
|
||||
|
||||
|
||||
# Special case: connect last qubit to first qubit for full entanglement
|
||||
if self.n_qubits > 1:
|
||||
qml.CNOT(wires=[self.n_qubits - 1, 0])
|
||||
|
||||
|
||||
# Measure all qubits in the computational basis
|
||||
return [qml.expval(qml.PauliZ(i)) for i in range(self.n_qubits)]
|
||||
|
||||
|
||||
self.circuit = circuit
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
# Handle single item or batch
|
||||
if x.dim() == 1:
|
||||
x = x.unsqueeze(0)
|
||||
|
||||
|
||||
batch_size = x.shape[0]
|
||||
results = []
|
||||
|
||||
|
||||
# Process each item in the batch
|
||||
for i in range(batch_size):
|
||||
# Normalize input values to [-1, 1] for quantum embedding
|
||||
x_norm = torch.tanh(x[i, :self.n_qubits])
|
||||
x_norm = torch.tanh(x[i, : self.n_qubits])
|
||||
# Get quantum circuit output
|
||||
try:
|
||||
quantum_out = torch.tensor(self.circuit(x_norm, self.params), dtype=torch.float32)
|
||||
quantum_out = torch.tensor(
|
||||
self.circuit(x_norm, self.params), dtype=torch.float32
|
||||
)
|
||||
results.append(quantum_out)
|
||||
except Exception as e:
|
||||
print(f"Quantum circuit error: {e}")
|
||||
# Fallback to random values if quantum circuit fails
|
||||
results.append(torch.randn(self.n_qubits))
|
||||
|
||||
|
||||
return torch.stack(results)
|
||||
|
||||
|
||||
class OptimizedHybridModel(torch.nn.Module):
|
||||
"""Hybrid model combining classical transformer model with quantum layers."""
|
||||
|
||||
|
||||
def __init__(self, base_model_name, n_qubits=8, n_layers=3, vocab_size=50257):
|
||||
super().__init__()
|
||||
|
||||
|
||||
# We'll simulate the classical model behavior instead of loading full model
|
||||
# to avoid memory issues in this environment
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = 768 # GPT2 hidden size
|
||||
|
||||
|
||||
# Dimensionality reduction to quantum space
|
||||
self.classical_to_quantum = torch.nn.Linear(self.hidden_size, n_qubits)
|
||||
|
||||
|
||||
# Quantum circuit layers
|
||||
self.quantum_layer1 = OptimizedQuantumLayer(n_qubits, n_layers)
|
||||
self.quantum_layer2 = OptimizedQuantumLayer(n_qubits, n_layers)
|
||||
|
||||
|
||||
# Map quantum output back to vocabulary space
|
||||
self.quantum_to_logits = torch.nn.Linear(n_qubits, self.vocab_size)
|
||||
|
||||
|
||||
# Mixing parameter
|
||||
self.alpha = torch.nn.Parameter(torch.tensor([0.5]))
|
||||
|
||||
|
||||
# Simple classical head for baseline comparison
|
||||
self.classical_head = torch.nn.Linear(self.hidden_size, self.vocab_size)
|
||||
|
||||
|
||||
def forward(self, hidden_states, return_classical=False):
|
||||
"""
|
||||
Args:
|
||||
hidden_states: [batch_size, hidden_size] tensor
|
||||
return_classical: if True, return classical logits for comparison
|
||||
"""
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# Classical pathway
|
||||
classical_logits = self.classical_head(hidden_states)
|
||||
|
||||
|
||||
if return_classical:
|
||||
return classical_logits
|
||||
|
||||
|
||||
# Quantum pathway
|
||||
try:
|
||||
# Reduce dimensionality for quantum processing
|
||||
quantum_input = self.classical_to_quantum(hidden_states)
|
||||
|
||||
|
||||
# Process through quantum circuits
|
||||
quantum_output1 = self.quantum_layer1(quantum_input)
|
||||
quantum_output2 = self.quantum_layer2(quantum_output1)
|
||||
|
||||
|
||||
# Convert to vocabulary space
|
||||
quantum_logits = self.quantum_to_logits(quantum_output2)
|
||||
|
||||
|
||||
# Combine classical and quantum predictions
|
||||
alpha = torch.sigmoid(self.alpha)
|
||||
combined_logits = alpha * classical_logits + (1 - alpha) * quantum_logits
|
||||
|
||||
|
||||
return combined_logits
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Quantum forward pass error: {e}, using classical only")
|
||||
return classical_logits
|
||||
|
|
@ -179,7 +191,7 @@ class OptimizedHybridModel(torch.nn.Module):
|
|||
|
||||
class QuantumHybridEnv(BaseEnv):
|
||||
"""Environment for training and evaluating quantum-classical hybrid models."""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: QuantumHybridConfig,
|
||||
|
|
@ -198,11 +210,11 @@ class QuantumHybridEnv(BaseEnv):
|
|||
"quantum_loss": [],
|
||||
"alpha_value": [],
|
||||
}
|
||||
|
||||
|
||||
# Initialize models
|
||||
self.hybrid_model = None
|
||||
self.optimizer = None
|
||||
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[QuantumHybridConfig, List[APIServerConfig]]:
|
||||
"""Initialize default configuration."""
|
||||
|
|
@ -225,7 +237,7 @@ class QuantumHybridEnv(BaseEnv):
|
|||
train_hybrid_model=True,
|
||||
compare_with_classical=True,
|
||||
)
|
||||
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name="NousResearch/Hermes-3-Llama-3.1-70B",
|
||||
|
|
@ -236,42 +248,44 @@ class QuantumHybridEnv(BaseEnv):
|
|||
num_requests_for_eval=8,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
return config, server_configs
|
||||
|
||||
|
||||
async def setup(self):
|
||||
"""Set up the environment and initialize models."""
|
||||
# Use synthetic data to avoid HuggingFace issues
|
||||
print("Setting up quantum-hybrid model training environment...")
|
||||
|
||||
|
||||
# Simple tokenizer
|
||||
class SimpleTokenizer:
|
||||
def __init__(self):
|
||||
self.vocab_size = 1000
|
||||
self.pad_token = "[PAD]"
|
||||
self.eos_token = "[EOS]"
|
||||
|
||||
|
||||
def encode(self, text, **kwargs):
|
||||
# Simple word-based tokenization
|
||||
words = text.split()[:50]
|
||||
return [hash(word) % self.vocab_size for word in words]
|
||||
|
||||
|
||||
def apply_chat_template(self, messages, **kwargs):
|
||||
return f"User: {messages[0]['content']}\nAssistant:"
|
||||
|
||||
|
||||
self.tokenizer = SimpleTokenizer()
|
||||
|
||||
|
||||
# Initialize hybrid model
|
||||
self.hybrid_model = OptimizedHybridModel(
|
||||
base_model_name=self.config.base_model_name,
|
||||
n_qubits=self.config.n_qubits,
|
||||
n_layers=self.config.n_layers,
|
||||
vocab_size=self.tokenizer.vocab_size
|
||||
vocab_size=self.tokenizer.vocab_size,
|
||||
)
|
||||
|
||||
|
||||
# Initialize optimizer for quantum parameters
|
||||
self.optimizer = torch.optim.Adam(self.hybrid_model.parameters(), lr=self.config.learning_rate)
|
||||
|
||||
self.optimizer = torch.optim.Adam(
|
||||
self.hybrid_model.parameters(), lr=self.config.learning_rate
|
||||
)
|
||||
|
||||
# Sample texts for training
|
||||
self.sample_texts = [
|
||||
"The quick brown fox jumps over the lazy dog.",
|
||||
|
|
@ -283,38 +297,38 @@ class QuantumHybridEnv(BaseEnv):
|
|||
"Artificial intelligence mimics human cognitive functions.",
|
||||
"Computer vision enables machines to interpret images.",
|
||||
"Robotics integrates mechanical and software engineering.",
|
||||
"Data science extracts insights from large datasets."
|
||||
"Data science extracts insights from large datasets.",
|
||||
]
|
||||
|
||||
|
||||
self.iter = 0
|
||||
print("Setup complete! Ready to train quantum-hybrid models.")
|
||||
|
||||
|
||||
async def get_next_item(self):
|
||||
"""Get the next training item."""
|
||||
import random
|
||||
|
||||
|
||||
# Select random text
|
||||
text = random.choice(self.sample_texts)
|
||||
text = f"Example {self.iter + 1}: {text}"
|
||||
self.iter += 1
|
||||
|
||||
|
||||
# Create target for next-word prediction
|
||||
tokens = self.tokenizer.encode(text)
|
||||
|
||||
|
||||
# Convert to messages
|
||||
user_msg = {"role": "user", "content": text}
|
||||
prompt = tuple([frozenset(user_msg.items())])
|
||||
|
||||
|
||||
return (prompt, tokens)
|
||||
|
||||
|
||||
async def collect_trajectories(self, item: Tuple) -> Tuple[ScoredDataGroup, List]:
|
||||
"""Generate responses and train hybrid model."""
|
||||
prompt, target_tokens = item
|
||||
user_content = dict(prompt[0])["content"]
|
||||
|
||||
|
||||
# Generate from external model (Hermes-3-70B)
|
||||
messages = [{"role": "user", "content": user_content}]
|
||||
|
||||
|
||||
try:
|
||||
prompt_text = f"User: {user_content}\nAssistant:"
|
||||
completions = await self.server.completion(
|
||||
|
|
@ -326,190 +340,218 @@ class QuantumHybridEnv(BaseEnv):
|
|||
except Exception as e:
|
||||
print(f"API error: {e}, using fallback responses")
|
||||
# Fallback responses
|
||||
completions = type('obj', (object,), {
|
||||
'choices': [
|
||||
type('choice', (object,), {'text': f"This is response {i+1} to: {user_content[:50]}..."})()
|
||||
for i in range(self.config.group_size)
|
||||
]
|
||||
})()
|
||||
|
||||
completions = type(
|
||||
"obj",
|
||||
(object,),
|
||||
{
|
||||
"choices": [
|
||||
type(
|
||||
"choice",
|
||||
(object,),
|
||||
{
|
||||
"text": f"This is response {i+1} to: {user_content[:50]}..."
|
||||
},
|
||||
)()
|
||||
for i in range(self.config.group_size)
|
||||
]
|
||||
},
|
||||
)()
|
||||
|
||||
to_score = []
|
||||
|
||||
|
||||
# Train hybrid model on each response
|
||||
for completion in completions.choices:
|
||||
completion_text = completion.text
|
||||
|
||||
|
||||
# Create trajectory
|
||||
trajectory_messages = messages.copy()
|
||||
trajectory_messages.append({"role": "assistant", "content": completion_text})
|
||||
|
||||
trajectory_messages.append(
|
||||
{"role": "assistant", "content": completion_text}
|
||||
)
|
||||
|
||||
# Train hybrid model
|
||||
if self.config.train_hybrid_model:
|
||||
await self._train_hybrid_model(completion_text, target_tokens)
|
||||
|
||||
|
||||
to_score.append((tuple(trajectory_messages), target_tokens))
|
||||
|
||||
|
||||
# Score the results
|
||||
scored_data = await self.score(to_score)
|
||||
|
||||
|
||||
return scored_data, []
|
||||
|
||||
|
||||
async def _train_hybrid_model(self, generated_text: str, target_tokens: List[int]):
|
||||
"""Train the hybrid model on generated text."""
|
||||
try:
|
||||
# Create synthetic hidden states (simulate transformer encoder output)
|
||||
hidden_states = torch.randn(1, self.hybrid_model.hidden_size)
|
||||
|
||||
|
||||
# Get predictions from hybrid and classical models
|
||||
hybrid_logits = self.hybrid_model(hidden_states)
|
||||
classical_logits = self.hybrid_model(hidden_states, return_classical=True)
|
||||
|
||||
|
||||
# Create targets (simplified - use first few target tokens)
|
||||
max_tokens = min(len(target_tokens), 10)
|
||||
targets = torch.tensor(target_tokens[:max_tokens])
|
||||
|
||||
|
||||
# Calculate losses
|
||||
if max_tokens > 0:
|
||||
# Repeat logits for each target token
|
||||
hybrid_logits_expanded = hybrid_logits.repeat(max_tokens, 1)
|
||||
classical_logits_expanded = classical_logits.repeat(max_tokens, 1)
|
||||
|
||||
|
||||
hybrid_loss = F.cross_entropy(hybrid_logits_expanded, targets)
|
||||
classical_loss = F.cross_entropy(classical_logits_expanded, targets)
|
||||
|
||||
|
||||
# Optimize hybrid model
|
||||
self.optimizer.zero_grad()
|
||||
hybrid_loss.backward()
|
||||
self.optimizer.step()
|
||||
|
||||
|
||||
# Log metrics
|
||||
self.metrics_buffer["hybrid_loss"].append(hybrid_loss.item())
|
||||
self.metrics_buffer["classical_loss"].append(classical_loss.item())
|
||||
self.metrics_buffer["alpha_value"].append(torch.sigmoid(self.hybrid_model.alpha).item())
|
||||
|
||||
self.metrics_buffer["alpha_value"].append(
|
||||
torch.sigmoid(self.hybrid_model.alpha).item()
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Training error: {e}")
|
||||
|
||||
|
||||
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
|
||||
"""Score responses using quantum-enhanced metrics."""
|
||||
scores = ScoredDataGroup()
|
||||
scores["tokens"] = []
|
||||
scores["masks"] = []
|
||||
scores["scores"] = []
|
||||
|
||||
|
||||
for item in rollout_group_data:
|
||||
messages, targets = item
|
||||
|
||||
|
||||
# Simple tokenization
|
||||
generated_text = messages[-1]["content"]
|
||||
tokens = self.tokenizer.encode(generated_text)
|
||||
|
||||
|
||||
# Pad tokens to consistent length
|
||||
max_len = 64
|
||||
if len(tokens) < max_len:
|
||||
tokens.extend([0] * (max_len - len(tokens)))
|
||||
tokens = tokens[:max_len]
|
||||
|
||||
|
||||
# Create masks (all ones for simplicity)
|
||||
masks = [1] * len(tokens)
|
||||
|
||||
|
||||
# Calculate scores
|
||||
text_quality_score = min(len(generated_text) / 100, 1.0)
|
||||
|
||||
|
||||
# Quantum coherence from hybrid model
|
||||
if hasattr(self, 'hybrid_model') and self.hybrid_model is not None:
|
||||
if hasattr(self, "hybrid_model") and self.hybrid_model is not None:
|
||||
try:
|
||||
hidden_states = torch.randn(1, self.hybrid_model.hidden_size)
|
||||
with torch.no_grad():
|
||||
hybrid_logits = self.hybrid_model(hidden_states)
|
||||
quantum_contribution = 1 - torch.sigmoid(self.hybrid_model.alpha).item()
|
||||
self.hybrid_model(hidden_states) # Run forward pass
|
||||
quantum_contribution = (
|
||||
1 - torch.sigmoid(self.hybrid_model.alpha).item()
|
||||
)
|
||||
quantum_score = quantum_contribution + 0.3 * random.random()
|
||||
except:
|
||||
except Exception:
|
||||
quantum_score = 0.5 + 0.3 * random.random()
|
||||
else:
|
||||
quantum_score = 0.5 + 0.3 * random.random()
|
||||
|
||||
|
||||
# Combined score
|
||||
combined_score = (
|
||||
self.config.perplexity_weight * text_quality_score +
|
||||
self.config.quantum_weight * quantum_score
|
||||
self.config.perplexity_weight * text_quality_score
|
||||
+ self.config.quantum_weight * quantum_score
|
||||
)
|
||||
|
||||
|
||||
# Update metrics
|
||||
self.metrics_buffer["perplexity"].append(text_quality_score)
|
||||
self.metrics_buffer["quantum_coherence"].append(quantum_score)
|
||||
self.metrics_buffer["combined_score"].append(combined_score)
|
||||
|
||||
|
||||
# Store results
|
||||
scores["tokens"].append(tokens)
|
||||
scores["masks"].append(masks)
|
||||
scores["scores"].append(2 * combined_score - 1) # Scale to [-1, 1]
|
||||
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""Evaluate the hybrid model."""
|
||||
eval_scores = []
|
||||
|
||||
|
||||
# Test on sample texts
|
||||
for text in self.sample_texts[:5]:
|
||||
try:
|
||||
# Test hybrid model performance
|
||||
if hasattr(self, 'hybrid_model') and self.hybrid_model is not None:
|
||||
if hasattr(self, "hybrid_model") and self.hybrid_model is not None:
|
||||
hidden_states = torch.randn(1, self.hybrid_model.hidden_size)
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
hybrid_logits = self.hybrid_model(hidden_states)
|
||||
classical_logits = self.hybrid_model(hidden_states, return_classical=True)
|
||||
|
||||
classical_logits = self.hybrid_model(
|
||||
hidden_states, return_classical=True
|
||||
)
|
||||
|
||||
# Compare distributions
|
||||
hybrid_entropy = -torch.sum(F.softmax(hybrid_logits, dim=-1) * F.log_softmax(hybrid_logits, dim=-1))
|
||||
classical_entropy = -torch.sum(F.softmax(classical_logits, dim=-1) * F.log_softmax(classical_logits, dim=-1))
|
||||
|
||||
hybrid_entropy = -torch.sum(
|
||||
F.softmax(hybrid_logits, dim=-1)
|
||||
* F.log_softmax(hybrid_logits, dim=-1)
|
||||
)
|
||||
classical_entropy = -torch.sum(
|
||||
F.softmax(classical_logits, dim=-1)
|
||||
* F.log_softmax(classical_logits, dim=-1)
|
||||
)
|
||||
|
||||
# Score based on entropy difference
|
||||
score = 1.0 - abs(hybrid_entropy - classical_entropy) / max(hybrid_entropy, classical_entropy)
|
||||
score = 1.0 - abs(hybrid_entropy - classical_entropy) / max(
|
||||
hybrid_entropy, classical_entropy
|
||||
)
|
||||
eval_scores.append(score.item())
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Evaluation error: {e}")
|
||||
eval_scores.append(0.5)
|
||||
|
||||
|
||||
if eval_scores:
|
||||
avg_score = sum(eval_scores) / len(eval_scores)
|
||||
self.eval_metrics.append(("eval/hybrid_performance", avg_score))
|
||||
|
||||
|
||||
# Log current alpha value
|
||||
if hasattr(self, 'hybrid_model') and self.hybrid_model is not None:
|
||||
if hasattr(self, "hybrid_model") and self.hybrid_model is not None:
|
||||
alpha_val = torch.sigmoid(self.hybrid_model.alpha).item()
|
||||
self.eval_metrics.append(("eval/alpha_value", alpha_val))
|
||||
self.eval_metrics.append(("eval/quantum_weight", 1 - alpha_val))
|
||||
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log metrics to Weights & Biases."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
|
||||
# Add buffered metrics
|
||||
for metric_name, values in self.metrics_buffer.items():
|
||||
if values:
|
||||
wandb_metrics[f"train/{metric_name}"] = sum(values) / len(values)
|
||||
if metric_name == "hybrid_loss" and len(values) > 1:
|
||||
wandb_metrics[f"train/{metric_name}_std"] = np.std(values)
|
||||
|
||||
|
||||
# Clear buffer
|
||||
self.metrics_buffer = {key: [] for key in self.metrics_buffer}
|
||||
|
||||
|
||||
# Add evaluation metrics
|
||||
for name, value in self.eval_metrics:
|
||||
wandb_metrics[name] = value
|
||||
|
||||
|
||||
self.eval_metrics = []
|
||||
|
||||
|
||||
# Log hybrid model parameters
|
||||
if hasattr(self, 'hybrid_model') and self.hybrid_model is not None:
|
||||
if hasattr(self, "hybrid_model") and self.hybrid_model is not None:
|
||||
wandb_metrics["model/alpha"] = torch.sigmoid(self.hybrid_model.alpha).item()
|
||||
wandb_metrics["model/quantum_contribution"] = 1 - torch.sigmoid(self.hybrid_model.alpha).item()
|
||||
|
||||
wandb_metrics["model/quantum_contribution"] = (
|
||||
1 - torch.sigmoid(self.hybrid_model.alpha).item()
|
||||
)
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue