mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
559 lines
20 KiB
Python
559 lines
20 KiB
Python
import os
import random
import zlib
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple

import numpy as np
import pennylane as qml
import torch
import torch.nn.functional as F
from pydantic import Field

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)
|
|
|
|
|
|
class QuantumHybridConfig(BaseEnvConfig):
    """Settings for the quantum-classical hybrid environment.

    Adds the quantum-circuit, dataset, base-model, scoring and training knobs
    on top of whatever ``BaseEnvConfig`` already provides.
    """

    # --- Quantum circuit -------------------------------------------------
    n_qubits: int = Field(
        default=8, description="Number of qubits in the quantum circuit"
    )
    n_layers: int = Field(
        default=3, description="Number of quantum layers to use in the circuit"
    )

    # --- Dataset ---------------------------------------------------------
    dataset_name: str = Field(
        default="wikitext", description="Dataset to use for training/evaluation"
    )
    dataset_config: str = Field(
        default="wikitext-2-raw-v1", description="Dataset configuration"
    )
    sequence_length: int = Field(
        default=256, description="Length of sequences to process"
    )

    # --- Base model ------------------------------------------------------
    base_model_name: str = Field(
        default="gpt2", description="Base model for hybrid experiments"
    )

    # --- Training / scoring ----------------------------------------------
    eval_every: int = Field(
        default=10, description="Evaluate every N training steps"
    )
    perplexity_weight: float = Field(
        default=0.7, description="Weight for perplexity in scoring"
    )
    quantum_weight: float = Field(
        default=0.3, description="Weight for quantum-specific metrics"
    )

    # --- Hybrid model ----------------------------------------------------
    train_hybrid_model: bool = Field(
        default=True, description="Whether to train the hybrid model"
    )
    learning_rate: float = Field(
        default=1e-4, description="Learning rate for quantum parameters"
    )

    # --- Comparison ------------------------------------------------------
    compare_with_classical: bool = Field(
        default=True, description="Compare hybrid with classical model"
    )
|
|
|
|
|
|
class OptimizedQuantumLayer(torch.nn.Module):
    """Variational quantum circuit exposed as a PyTorch module (via PennyLane).

    The circuit encodes ``n_qubits`` input values as RY rotation angles, applies
    ``n_layers`` blocks of trainable RY rotations followed by a ring of CNOT
    gates, and returns the Pauli-Z expectation value of every qubit.
    """

    def __init__(self, n_qubits=8, n_layers=3):
        """Build the simulator device, the trainable angles and the QNode.

        Args:
            n_qubits: Number of wires in the circuit; also the size of the
                input slice consumed and of the output vector produced.
            n_layers: Number of rotation + entanglement blocks.
        """
        super().__init__()
        self.n_qubits = n_qubits
        self.n_layers = n_layers

        # Simulator backend with one wire per qubit.
        self.dev = qml.device("default.qubit", wires=n_qubits)

        # One trainable RY angle per (layer, qubit), initialised near zero.
        self.params = torch.nn.Parameter(torch.randn(n_layers, n_qubits) * 0.1)

        @qml.qnode(self.dev, interface="torch")
        def circuit(inputs, params):
            # Angle-encode the classical inputs, one value per qubit.
            for i in range(self.n_qubits):
                qml.RY(inputs[i], wires=i)

            for layer in range(self.n_layers):
                # Trainable single-qubit rotations.
                for i in range(self.n_qubits):
                    qml.RY(params[layer, i], wires=i)

                # Nearest-neighbour entanglement along the chain ...
                for i in range(self.n_qubits - 1):
                    qml.CNOT(wires=[i, i + 1])

                # ... closed into a ring by linking the last qubit to the first.
                if self.n_qubits > 1:
                    qml.CNOT(wires=[self.n_qubits - 1, 0])

            # One Pauli-Z expectation value per qubit.
            return [qml.expval(qml.PauliZ(i)) for i in range(self.n_qubits)]

        self.circuit = circuit

    def _run_circuit(self, angles):
        """Evaluate the circuit for one sample, keeping the autograd graph."""
        raw = self.circuit(angles, self.params)
        # The circuit may hand back a sequence of scalars or a single tensor.
        # Stack/cast instead of rebuilding with ``torch.tensor(...)``: that
        # call copies the values and detaches them, which would stop gradients
        # from ever reaching ``self.params`` or the layers upstream.
        if isinstance(raw, (list, tuple)):
            raw = torch.stack([torch.as_tensor(value) for value in raw])
        return torch.as_tensor(raw).to(torch.float32)

    def forward(self, x):
        """Run every sample of ``x`` through the circuit.

        Args:
            x: A 1-D tensor (single sample) or a 2-D tensor with one sample per
                row. Only the first ``n_qubits`` entries of each sample are
                used.

        Returns:
            A float32 tensor with one row of ``n_qubits`` values per sample. A
            1-D input yields a single-row result.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)

        outputs = []
        for sample in x:
            # tanh squashes the features into (-1, 1) before they are used as
            # rotation angles.
            angles = torch.tanh(sample[: self.n_qubits])
            try:
                outputs.append(self._run_circuit(angles))
            except Exception as e:
                # Deliberate best-effort: keep the batch shape intact with
                # random values rather than aborting the whole forward pass.
                print(f"Quantum circuit error: {e}")
                outputs.append(torch.randn(self.n_qubits))

        return torch.stack(outputs)
|
|
|
|
|
|
class OptimizedHybridModel(torch.nn.Module):
    """Blend of a classical linear head and a two-stage quantum pathway.

    No pretrained transformer is loaded: the model works directly on hidden
    state vectors of size 768 and maps them to vocabulary logits.
    """

    def __init__(self, base_model_name, n_qubits=8, n_layers=3, vocab_size=50257):
        """Create the classical head, the quantum pathway and the mixing weight.

        Args:
            base_model_name: Accepted for interface compatibility; it is not
                used when building the layers.
            n_qubits: Width of the quantum pathway.
            n_layers: Number of layers inside each quantum circuit.
            vocab_size: Size of the output logit vector.
        """
        super().__init__()

        # The classical transformer is only simulated here (no weights are
        # loaded) to keep memory usage low.
        self.vocab_size = vocab_size
        self.hidden_size = 768  # GPT2 hidden size

        # Projection from hidden space down to one value per qubit.
        self.classical_to_quantum = torch.nn.Linear(self.hidden_size, n_qubits)

        # Two quantum circuits applied back to back.
        self.quantum_layer1 = OptimizedQuantumLayer(n_qubits, n_layers)
        self.quantum_layer2 = OptimizedQuantumLayer(n_qubits, n_layers)

        # Projection from qubit measurements to vocabulary logits.
        self.quantum_to_logits = torch.nn.Linear(n_qubits, self.vocab_size)

        # Raw mixing weight; it is passed through a sigmoid before use.
        self.alpha = torch.nn.Parameter(torch.tensor([0.5]))

        # Purely classical head, also used as the comparison baseline.
        self.classical_head = torch.nn.Linear(self.hidden_size, self.vocab_size)

    def _blend_with_quantum(self, hidden_states, baseline_logits):
        """Return the sigmoid-weighted mix of baseline and quantum logits."""
        reduced = self.classical_to_quantum(hidden_states)
        measured = self.quantum_layer2(self.quantum_layer1(reduced))
        quantum_logits = self.quantum_to_logits(measured)

        alpha = torch.sigmoid(self.alpha)
        return alpha * baseline_logits + (1 - alpha) * quantum_logits

    def forward(self, hidden_states, return_classical=False):
        """Compute vocabulary logits for a batch of hidden states.

        Args:
            hidden_states: [batch_size, hidden_size] tensor
            return_classical: if True, return classical logits for comparison

        Returns:
            The blended logits, or the classical logits alone when
            ``return_classical`` is set or the quantum pathway raises.
        """
        baseline_logits = self.classical_head(hidden_states)

        if not return_classical:
            try:
                return self._blend_with_quantum(hidden_states, baseline_logits)
            except Exception as e:
                print(f"Quantum forward pass error: {e}, using classical only")

        return baseline_logits
|
|
|
|
|
|
class QuantumHybridEnv(BaseEnv):
    """Environment for training and evaluating quantum-classical hybrid models.

    Each rollout asks the configured inference server for completions of a
    sample sentence, optionally takes one optimisation step of the local
    hybrid model per completion, and scores the completions from their length
    and the hybrid model's current quantum mixing weight.
    """

    # Environment variable that supplies the inference API key. Credentials
    # must never be committed to source control.
    API_KEY_ENV_VAR = "NOUS_API_KEY"

    def __init__(
        self,
        config: QuantumHybridConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        """Store the config and create empty metric buffers.

        The hybrid model and its optimizer are created later, in ``setup``.
        """
        super().__init__(config, server_configs, slurm, testing)
        self.config = config
        self.metrics_buffer = {
            "perplexity": [],
            "quantum_coherence": [],
            "combined_score": [],
            "hybrid_loss": [],
            "classical_loss": [],
            "quantum_loss": [],
            "alpha_value": [],
        }

        # ``evaluate`` appends to this list and ``wandb_log`` drains it, so it
        # has to exist before either runs.
        # NOTE(review): initialised only when the base class has not already
        # provided it -- confirm against BaseEnv.
        if getattr(self, "eval_metrics", None) is None:
            self.eval_metrics = []

        # Populated by ``setup``.
        self.hybrid_model = None
        self.optimizer = None

    @classmethod
    def config_init(cls) -> Tuple[QuantumHybridConfig, List[APIServerConfig]]:
        """Return the default environment config and inference server list.

        The API key is read from the environment variable named by
        ``API_KEY_ENV_VAR``; a placeholder is used when it is unset, in which
        case requests to a server that checks keys are expected to fail.
        """
        config = QuantumHybridConfig(
            tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
            group_size=4,
            use_wandb=True,
            max_num_workers=2,
            rollout_server_url="http://localhost:8000",
            total_steps=50,
            batch_size=8,
            steps_per_eval=10,
            max_token_length=256,
            n_qubits=8,
            n_layers=3,
            dataset_name="wikitext",
            dataset_config="wikitext-2-raw-v1",
            sequence_length=256,
            base_model_name="gpt2",
            train_hybrid_model=True,
            compare_with_classical=True,
        )

        server_configs = [
            APIServerConfig(
                model_name="NousResearch/Hermes-3-Llama-3.1-70B",
                base_url="https://api.nousresearch.com/v1",
                api_key=os.environ.get(cls.API_KEY_ENV_VAR, "EMPTY"),
                server_type="openai",
                timeout=600,
                num_requests_for_eval=8,
            ),
        ]

        return config, server_configs

    async def setup(self):
        """Create the tokenizer, hybrid model, optimizer and sample texts."""
        # Synthetic data is used so that no dataset download is required.
        print("Setting up quantum-hybrid model training environment...")

        class SimpleTokenizer:
            """Whitespace tokenizer that hashes words into a fixed id range."""

            def __init__(self):
                self.vocab_size = 1000
                self.pad_token = "[PAD]"
                self.eos_token = "[EOS]"

            def encode(self, text, **kwargs):
                # Only the first 50 whitespace-separated words are kept.
                # CRC32 is used instead of the built-in ``hash``: string
                # hashes are salted per process, which would make token ids
                # differ from one run to the next.
                words = text.split()[:50]
                return [
                    zlib.crc32(word.encode("utf-8")) % self.vocab_size
                    for word in words
                ]

            def apply_chat_template(self, messages, **kwargs):
                return f"User: {messages[0]['content']}\nAssistant:"

        self.tokenizer = SimpleTokenizer()

        # The model's output size matches the tokenizer's id range so that
        # token ids can be used directly as classification targets.
        self.hybrid_model = OptimizedHybridModel(
            base_model_name=self.config.base_model_name,
            n_qubits=self.config.n_qubits,
            n_layers=self.config.n_layers,
            vocab_size=self.tokenizer.vocab_size,
        )

        # A single optimizer covers every parameter of the hybrid model
        # (classical head, projections, quantum angles and mixing weight).
        self.optimizer = torch.optim.Adam(
            self.hybrid_model.parameters(), lr=self.config.learning_rate
        )

        self.sample_texts = [
            "The quick brown fox jumps over the lazy dog.",
            "Machine learning combines statistics and computer science.",
            "Quantum computing uses quantum mechanics for computation.",
            "Natural language processing analyzes human language.",
            "Neural networks are inspired by biological brains.",
            "Deep learning uses multiple layers of neural networks.",
            "Artificial intelligence mimics human cognitive functions.",
            "Computer vision enables machines to interpret images.",
            "Robotics integrates mechanical and software engineering.",
            "Data science extracts insights from large datasets.",
        ]

        self.iter = 0
        print("Setup complete! Ready to train quantum-hybrid models.")

    async def get_next_item(self):
        """Return the next ``(prompt, tokens)`` training item.

        ``prompt`` is a one-element tuple holding the user message as a
        frozenset of its ``(key, value)`` pairs; ``tokens`` are the token ids
        of the same text.
        """
        text = random.choice(self.sample_texts)
        text = f"Example {self.iter + 1}: {text}"
        self.iter += 1

        tokens = self.tokenizer.encode(text)

        user_msg = {"role": "user", "content": text}
        prompt = (frozenset(user_msg.items()),)

        return (prompt, tokens)

    async def collect_trajectories(self, item: Tuple) -> Tuple[ScoredDataGroup, List]:
        """Generate completions for ``item``, train on them and score them.

        Args:
            item: A ``(prompt, target_tokens)`` pair as produced by
                ``get_next_item``.

        Returns:
            The scored group and an empty list.
        """
        prompt, target_tokens = item
        user_content = dict(prompt[0])["content"]
        messages = [{"role": "user", "content": user_content}]
        prompt_text = self.tokenizer.apply_chat_template(messages)

        try:
            completions = await self.server.completion(
                prompt=prompt_text,
                n=self.config.group_size,
                max_tokens=50,
                temperature=0.8,
            )
        except Exception as e:
            # Best-effort: keep the rollout going with canned completions that
            # expose the same ``.choices[i].text`` shape.
            print(f"API error: {e}, using fallback responses")
            completions = SimpleNamespace(
                choices=[
                    SimpleNamespace(
                        text=f"This is response {i+1} to: {user_content[:50]}..."
                    )
                    for i in range(self.config.group_size)
                ]
            )

        to_score = []
        for completion in completions.choices:
            completion_text = completion.text
            trajectory_messages = [
                *messages,
                {"role": "assistant", "content": completion_text},
            ]

            if self.config.train_hybrid_model:
                await self._train_hybrid_model(completion_text, target_tokens)

            to_score.append((tuple(trajectory_messages), target_tokens))

        scored_data = await self.score(to_score)
        return scored_data, []

    async def _train_hybrid_model(self, generated_text: str, target_tokens: List[int]):
        """Take one optimisation step of the hybrid model.

        The model input is a random hidden-state vector (``generated_text`` is
        not encoded), and the targets are the first ten ``target_tokens``.
        Errors are printed and swallowed so that a failed step never aborts a
        rollout.
        """
        try:
            targets = torch.tensor(target_tokens[:10], dtype=torch.long)
            if targets.numel() == 0:
                return
            n_targets = targets.shape[0]

            # Synthetic stand-in for a transformer's hidden state.
            hidden_states = torch.randn(1, self.hybrid_model.hidden_size)

            # The single prediction is compared against every target token.
            hybrid_logits = self.hybrid_model(hidden_states)
            hybrid_loss = F.cross_entropy(
                hybrid_logits.expand(n_targets, -1), targets
            )

            # The classical loss is only logged, so no graph is needed.
            with torch.no_grad():
                classical_logits = self.hybrid_model(
                    hidden_states, return_classical=True
                )
                classical_loss = F.cross_entropy(
                    classical_logits.expand(n_targets, -1), targets
                )

            self.optimizer.zero_grad()
            hybrid_loss.backward()
            self.optimizer.step()

            self.metrics_buffer["hybrid_loss"].append(hybrid_loss.item())
            self.metrics_buffer["classical_loss"].append(classical_loss.item())
            self.metrics_buffer["alpha_value"].append(
                torch.sigmoid(self.hybrid_model.alpha).item()
            )
        except Exception as e:
            print(f"Training error: {e}")

    def _quantum_score(self) -> float:
        """Return a noisy score in [0, 1] based on the quantum mixing weight."""
        if self.hybrid_model is None:
            return 0.5 + 0.3 * random.random()
        try:
            hidden_states = torch.randn(1, self.hybrid_model.hidden_size)
            with torch.no_grad():
                # Smoke-test the forward pass; its output is not used.
                self.hybrid_model(hidden_states)
            quantum_contribution = 1 - torch.sigmoid(self.hybrid_model.alpha).item()
            # Clamp so that the weighted sum computed by ``score`` stays
            # within the range its final rescaling assumes.
            return min(quantum_contribution + 0.3 * random.random(), 1.0)
        except Exception:
            return 0.5 + 0.3 * random.random()

    async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
        """Score each rollout from its text length and the quantum weight.

        Args:
            rollout_group_data: Iterable of ``(messages, targets)`` pairs; the
                last message of each is the generated completion.

        Returns:
            A ``ScoredDataGroup`` with ``tokens`` (padded/truncated to 64
            ids), ``masks`` (all ones) and ``scores`` entries, one per rollout.
        """
        scores = ScoredDataGroup()
        scores["tokens"] = []
        scores["masks"] = []
        scores["scores"] = []

        max_len = 64
        for messages, _targets in rollout_group_data:
            generated_text = messages[-1]["content"]

            # Pad with 0 / truncate so every row has the same length.
            tokens = self.tokenizer.encode(generated_text)[:max_len]
            tokens = tokens + [0] * (max_len - len(tokens))

            # Every position is unmasked, padding included (kept simple).
            masks = [1] * len(tokens)

            # Length-based proxy for text quality, saturating at 100 chars.
            text_quality_score = min(len(generated_text) / 100, 1.0)
            quantum_score = self._quantum_score()

            combined_score = (
                self.config.perplexity_weight * text_quality_score
                + self.config.quantum_weight * quantum_score
            )

            self.metrics_buffer["perplexity"].append(text_quality_score)
            self.metrics_buffer["quantum_coherence"].append(quantum_score)
            self.metrics_buffer["combined_score"].append(combined_score)

            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            # Maps a combined score in [0, 1] onto [-1, 1].
            scores["scores"].append(2 * combined_score - 1)

        return scores

    async def evaluate(self, *args, **kwargs):
        """Compare hybrid and classical output entropies on random inputs.

        Five trials are run (one per entry of the first five sample texts; the
        text itself is not fed to the model). Each trial scores how close the
        two entropies are, and the mean is queued as
        ``eval/hybrid_performance`` together with the current mixing weight.
        """
        eval_scores = []

        for _text in self.sample_texts[:5]:
            try:
                if self.hybrid_model is None:
                    continue

                hidden_states = torch.randn(1, self.hybrid_model.hidden_size)
                with torch.no_grad():
                    hybrid_logits = self.hybrid_model(hidden_states)
                    classical_logits = self.hybrid_model(
                        hidden_states, return_classical=True
                    )

                    hybrid_entropy = -torch.sum(
                        F.softmax(hybrid_logits, dim=-1)
                        * F.log_softmax(hybrid_logits, dim=-1)
                    )
                    classical_entropy = -torch.sum(
                        F.softmax(classical_logits, dim=-1)
                        * F.log_softmax(classical_logits, dim=-1)
                    )

                    # 1.0 when the entropies match, lower as they diverge.
                    score = 1.0 - abs(hybrid_entropy - classical_entropy) / max(
                        hybrid_entropy, classical_entropy
                    )
                    eval_scores.append(score.item())
            except Exception as e:
                print(f"Evaluation error: {e}")
                eval_scores.append(0.5)

        if eval_scores:
            avg_score = sum(eval_scores) / len(eval_scores)
            self.eval_metrics.append(("eval/hybrid_performance", avg_score))

        if self.hybrid_model is not None:
            alpha_val = torch.sigmoid(self.hybrid_model.alpha).item()
            self.eval_metrics.append(("eval/alpha_value", alpha_val))
            self.eval_metrics.append(("eval/quantum_weight", 1 - alpha_val))

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Merge buffered metrics into ``wandb_metrics`` and forward them.

        Training buffers are averaged (plus a standard deviation for
        ``hybrid_loss``) and cleared, queued evaluation metrics are copied and
        cleared, and the current mixing weight is added before delegating to
        the base class.
        """
        if wandb_metrics is None:
            wandb_metrics = {}

        for metric_name, values in self.metrics_buffer.items():
            if not values:
                continue
            wandb_metrics[f"train/{metric_name}"] = sum(values) / len(values)
            if metric_name == "hybrid_loss" and len(values) > 1:
                wandb_metrics[f"train/{metric_name}_std"] = np.std(values)

        self.metrics_buffer = {key: [] for key in self.metrics_buffer}

        for name, value in self.eval_metrics:
            wandb_metrics[name] = value
        self.eval_metrics = []

        if self.hybrid_model is not None:
            alpha_val = torch.sigmoid(self.hybrid_model.alpha).item()
            wandb_metrics["model/alpha"] = alpha_val
            wandb_metrics["model/quantum_contribution"] = 1 - alpha_val

        await super().wandb_log(wandb_metrics)
|
|
|
|
|
|
# Script entry point: hand control to the environment's command-line interface.
if __name__ == "__main__":
    QuantumHybridEnv.cli()