import random
from typing import Dict, List, Optional, Tuple

import numpy as np
import pennylane as qml
from datasets import load_dataset
from pydantic import Field
from transformers import AutoTokenizer

from atroposlib.envs.base import (
    APIServerConfig,
    BaseEnv,
    BaseEnvConfig,
    ScoredDataGroup,
)
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer


class QuantumHybridConfig(BaseEnvConfig):
    """Configuration for the Quantum-Classical Hybrid Environment."""

    # Quantum circuit parameters
    n_qubits: int = Field(8, description="Number of qubits in the quantum circuit")
    n_layers: int = Field(
        3, description="Number of quantum layers to use in the circuit"
    )

    # Dataset parameters
    dataset_name: str = Field(
        "wikitext", description="Dataset to use for training/evaluation"
    )
    dataset_config: str = Field(
        "wikitext-2-raw-v1", description="Dataset configuration"
    )
    sequence_length: int = Field(256, description="Length of sequences to process")

    # Base model parameters
    base_model_name: str = Field(
        "gpt2", description="Base model for hybrid experiments"
    )

    # Training parameters
    perplexity_weight: float = Field(
        0.7, description="Weight for perplexity in scoring"
    )
    quantum_weight: float = Field(
        0.3, description="Weight for quantum-specific metrics"
    )
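
# For reference: QuantumHybridEnv.score() combines the two weights above as
#   combined = perplexity_weight * perplexity_score + quantum_weight * quantum_coherence
# With the defaults (0.7 / 0.3), a perplexity-like score of 0.5 and a quantum
# coherence of 0.8 give 0.7 * 0.5 + 0.3 * 0.8 = 0.59, which is rescaled to
# 2 * 0.59 - 1 = 0.18 before being handed back to the trainer.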


class QuantumTextAnalyzer:
    """Standalone quantum analyzer for text coherence measurement."""

    def __init__(self, n_qubits=6):
        self.n_qubits = n_qubits
        self.dev = qml.device("default.qubit", wires=n_qubits)

        @qml.qnode(self.dev)
        def text_analysis_circuit(text_features):
            # Embed text features as quantum states
            for i in range(self.n_qubits):
                qml.RY(text_features[i], wires=i)

            # Create entanglement patterns
            for i in range(self.n_qubits - 1):
                qml.CNOT(wires=[i, i + 1])

            # Ring closure
            if self.n_qubits > 1:
                qml.CNOT(wires=[self.n_qubits - 1, 0])

            # Additional entanglement for complex analysis
            for i in range(0, self.n_qubits - 2, 2):
                qml.CNOT(wires=[i, i + 2])

            # Measure coherence
            return [qml.expval(qml.PauliZ(i)) for i in range(self.n_qubits)]

        self.circuit = text_analysis_circuit
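        # The QNode returns one Pauli-Z expectation value per qubit, each in
        # [-1, 1]; analyze_text() turns the spread (variance) of these values
        # into a single coherence score.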

    def analyze_text(self, text: str) -> float:
        """Analyze text and return quantum coherence score."""
        try:
            # Extract text features
            text_len = min(len(text), 200) / 200.0
            word_count = len(text.split()) / 100.0 if text.split() else 0.0
            char_diversity = len(set(text.lower())) / 26.0 if text else 0.0
            avg_word_len = (
                np.mean([len(word) for word in text.split()]) / 15.0
                if text.split()
                else 0.0
            )

            # Additional features
            punctuation_ratio = sum(1 for c in text if c in ".,!?;:") / max(
                len(text), 1
            )
            uppercase_ratio = sum(1 for c in text if c.isupper()) / max(len(text), 1)

            # Encode as quantum features
            features = [
                text_len * np.pi,
                word_count * np.pi,
                char_diversity * np.pi,
                avg_word_len * np.pi,
                punctuation_ratio * np.pi,
                uppercase_ratio * np.pi,
            ]

            # Pad or truncate so the feature vector matches the circuit width;
            # without this, any n_qubits other than 6 raises an IndexError in
            # the QNode and the method always falls back to the heuristic below.
            features = (features + [0.0] * self.n_qubits)[: self.n_qubits]

            # Run quantum analysis
            measurements = self.circuit(features)

            # Calculate coherence as normalized measurement variance
            coherence = np.var(measurements) / (np.var(measurements) + 0.1)
            return float(np.clip(coherence, 0.0, 1.0))

        except Exception as e:
            print(f"Quantum text analysis error: {e}")
            # Fallback to simple heuristic
            return min(1.0, len(text) / 100.0) * 0.7 + 0.2
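
# Illustrative standalone usage of the analyzer (not part of the training loop):
#
#   analyzer = QuantumTextAnalyzer(n_qubits=6)
#   coherence = analyzer.analyze_text("Quantum circuits can help score text.")
#   print(coherence)  # a float in [0.0, 1.0]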


class QuantumHybridEnv(BaseEnv):
    """Environment for training and evaluating quantum-classical hybrid models."""

    def __init__(
        self,
        config: QuantumHybridConfig,
        server_configs: List[APIServerConfig],
        slurm=True,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.config = config
        self.metrics_buffer = {
            "perplexity": [],
            "quantum_coherence": [],
            "combined_score": [],
            "quantum_variance": [],
        }

        # Initialize eval_metrics list (required for BaseEnv)
        self.eval_metrics = []

        # Initialize quantum text analyzer
        self.quantum_analyzer = QuantumTextAnalyzer(n_qubits=config.n_qubits)

        # Track training iteration
        self.iter = 0

    @classmethod
    def config_init(cls) -> Tuple[QuantumHybridConfig, List[APIServerConfig]]:
        """Initialize default configuration."""
        config = QuantumHybridConfig(
            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
            group_size=4,
            use_wandb=True,
            max_num_workers=-1,
            rollout_server_url="http://localhost:8000",  # Atropos server
            total_steps=20,
            batch_size=-1,
            max_token_length=2048,
            data_path_to_save_groups="data/quantum_hybrid.jsonl",
            n_qubits=8,
            n_layers=3,
            dataset_name="wikitext",
            dataset_config="wikitext-2-raw-v1",
            sequence_length=256,
            base_model_name="gpt2",
            perplexity_weight=0.7,
            quantum_weight=0.3,
        )

        # The server config here tells Atropos to route to your vLLM server
        # This should match whatever model name the Atropos server expects
        server_configs = [
            APIServerConfig(
                model_name="hermes-3-8b",  # This model name should be registered in Atropos
                base_url="http://localhost:9001/v1",  # Your vLLM server
                api_key="x",  # Placeholder for local server
                timeout=300,
                num_max_requests_at_once=8,
                num_requests_for_eval=4,
                health_check=False,
                server_type="openai",
                n_kwarg_is_ignored=False,
                rolling_buffer_length=100,
            ),
        ]

        return config, server_configs

    async def setup(self):
        """Set up the environment, including loading datasets."""
        print(
            f"Setting up QuantumHybridEnv with quantum parameters: "
            f"{self.config.n_qubits} qubits, {self.config.n_layers} layers"
        )

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load dataset
        try:
            dataset = load_dataset(
                self.config.dataset_name,
                self.config.dataset_config,
                split="train",
                streaming=True,
            )
            self.dataset = dataset.take(10000)  # Take first 10k examples

            eval_dataset = load_dataset(
                self.config.dataset_name,
                self.config.dataset_config,
                split="validation",
                streaming=True,
            )
            self.eval_dataset = eval_dataset.take(100)  # Take first 100 for eval

            print(
                f"Loaded dataset: {self.config.dataset_name}/{self.config.dataset_config}"
            )

        except Exception as e:
            print(f"Failed to load dataset: {e}")
            # Create a mock dataset for testing
            self.dataset = [
                {
                    "text": f"Sample text {i} for quantum analysis. This text demonstrates complex linguistic patterns."
                }
                for i in range(1000)
            ]
            self.eval_dataset = [
                {"text": f"Eval text {i} with various complexity levels."}
                for i in range(100)
            ]
            print("Using mock dataset")

        # Convert to lists for easier iteration
        if hasattr(self.dataset, "__iter__"):
            self.train_examples = list(self.dataset)
        else:
            self.train_examples = self.dataset

        if hasattr(self.eval_dataset, "__iter__"):
            self.eval_examples = list(self.eval_dataset)
        else:
            self.eval_examples = self.eval_dataset

        print(f"Loaded {len(self.train_examples)} training examples")
        print(f"Loaded {len(self.eval_examples)} evaluation examples")

    async def get_next_item(self):
        """Get the next training item from the dataset."""
        # Get next instance from the dataset
        if self.iter >= len(self.train_examples):
            self.iter = 0  # Reset to beginning

        data_point = self.train_examples[self.iter]
        self.iter += 1

        # Process text data
        if isinstance(data_point, dict) and "text" in data_point:
            text = data_point["text"]
        else:
            text = str(data_point)

        # Truncate text to reasonable length
        text = text[: self.config.sequence_length * 4]  # Allow for some context

        # Create a simple continuation task
        words = text.split()
        if len(words) > 10:
            # Split at a random point for continuation
            split_idx = random.randint(5, min(len(words) - 5, 20))
            prompt_text = " ".join(words[:split_idx])
            target_text = " ".join(words[split_idx : split_idx + 20])
        else:
            prompt_text = text[: len(text) // 2]
            target_text = text[len(text) // 2 :]

        # Convert to messages format for Atropos
        user_msg = {"role": "user", "content": f"Continue this text: {prompt_text}"}

        # Return as (messages, target) tuple
        return ([user_msg], target_text)
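
    # Example of an item produced above (illustrative values only):
    #   (
    #       [{"role": "user", "content": "Continue this text: The quick brown fox"}],
    #       "jumps over the lazy dog ...",
    #   )
    # collect_trajectories() receives exactly this tuple.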

    async def collect_trajectories(self, item: Tuple) -> Tuple[ScoredDataGroup, List]:
        """Generate and collect model responses for scoring."""
        messages, target_text = item

        print(f"Generating completions for: {messages[0]['content'][:50]}...")

        to_score = []

        # Generate multiple completions with different temperatures
        temperatures = [0.6, 0.8, 1.0, 1.2][: self.config.group_size]

        for i, temp in enumerate(temperatures):
            try:
                # Use the Atropos server to generate completions
                completion = await self.server.completion(
                    prompt=self.tokenizer.apply_chat_template(messages, tokenize=False),
                    n=1,  # Generate one completion at a time
                    max_tokens=min(100, self.config.max_token_length // 4),
                    temperature=temp,
                )

                completion_text = completion.choices[0].text.strip()
                print(
                    f"Generated completion {i+1} (T={temp}): {completion_text[:50]}..."
                )

                # Create trajectory messages
                trajectory_messages = messages.copy()
                trajectory_messages.append(
                    {"role": "assistant", "content": completion_text}
                )

                # Add to scoring queue
                to_score.append(
                    (
                        tuple(frozenset(msg.items()) for msg in trajectory_messages),
                        target_text,
                        completion_text,
                    )
                )

            except Exception as e:
                print(f"Error generating completion {i+1} with temperature {temp}: {e}")
                # Create a mock completion as fallback
                mock_text = (
                    f"Mock quantum-enhanced response {i+1}: This demonstrates "
                    f"coherent language generation with temperature {temp}."
                )
                trajectory_messages = messages.copy()
                trajectory_messages.append({"role": "assistant", "content": mock_text})
                to_score.append(
                    (
                        tuple(frozenset(msg.items()) for msg in trajectory_messages),
                        target_text,
                        mock_text,
                    )
                )

        # Score the generated trajectories
        scored_data = await self.score(to_score)

        return scored_data, []

    async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
        """Score model outputs based on perplexity and quantum metrics."""
        if not rollout_group_data:
            return None

        scores = ScoredDataGroup()
        scores["tokens"] = []
        scores["masks"] = []
        scores["scores"] = []

        print(f"Scoring {len(rollout_group_data)} completions...")

        for item in rollout_group_data:
            frozen_messages, target_text, completion_text = item

            # Convert frozen messages back to regular format
            messages = [dict(frozen_msg) for frozen_msg in frozen_messages]

            # Convert messages to tokens
            try:
                tokenized = tokenize_for_trainer(self.tokenizer, messages)
                tokens = tokenized["tokens"]
                masks = tokenized["masks"]
            except Exception as e:
                print(f"Tokenization error: {e}")
                continue

            # Calculate text similarity score (proxy for perplexity)
            completion_words = set(completion_text.lower().split())
            target_words = set(target_text.lower().split())

            # Jaccard similarity
            if completion_words and target_words:
                intersection = len(completion_words & target_words)
                union = len(completion_words | target_words)
                similarity_score = intersection / union if union > 0 else 0.0
            else:
                similarity_score = 0.0

            # Length penalty/bonus
            len_ratio = len(completion_text) / max(len(target_text), 1)
            len_score = 1.0 - abs(1.0 - len_ratio) if len_ratio > 0 else 0.0

            # Combined perplexity-like score
            perplexity_score = 0.7 * similarity_score + 0.3 * len_score

            # Quantum coherence analysis
            quantum_coherence = self.quantum_analyzer.analyze_text(completion_text)

            # Calculate quantum variance for additional insight
            quantum_variance = (
                abs(quantum_coherence - 0.5) * 2
            )  # Distance from maximum entropy

            # Combined score using weighted sum
            combined_score = (
                self.config.perplexity_weight * perplexity_score
                + self.config.quantum_weight * quantum_coherence
            )

            print(
                f"  Similarity: {similarity_score:.3f}, Quantum: {quantum_coherence:.3f}, "
                f"Combined: {combined_score:.3f}"
            )

            # Update metrics buffer
            self.metrics_buffer["perplexity"].append(perplexity_score)
            self.metrics_buffer["quantum_coherence"].append(quantum_coherence)
            self.metrics_buffer["combined_score"].append(combined_score)
            self.metrics_buffer["quantum_variance"].append(quantum_variance)

            # Store data for training (scale to [-1, 1] range)
            scores["tokens"].append(tokens)
            scores["masks"].append(masks)
            scores["scores"].append(2 * combined_score - 1)

        return scores if scores["scores"] else None
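
    # Worked example of the similarity term above (illustrative strings):
    #   completion = "the cat sat", target = "the cat ran"
    #   intersection = {"the", "cat"} -> 2, union = {"the", "cat", "sat", "ran"} -> 4
    #   Jaccard similarity = 2 / 4 = 0.5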

    async def evaluate(self, *args, **kwargs):
        """Evaluate the model on a test set."""
        print("Running quantum-enhanced evaluation...")
        eval_scores = []
        quantum_scores = []

        # Get evaluation examples
        eval_examples = self.eval_examples[: min(5, len(self.eval_examples))]

        for example in eval_examples:
            try:
                # Process text
                if isinstance(example, dict) and "text" in example:
                    text = example["text"]
                else:
                    text = str(example)

                # Create continuation task
                words = text.split()
                if len(words) > 8:
                    split_idx = len(words) // 2
                    prompt_text = " ".join(words[:split_idx])
                    target_text = " ".join(words[split_idx:])
                else:
                    prompt_text = text[: len(text) // 2]
                    target_text = text[len(text) // 2 :]

                # Create messages
                messages = [
                    {"role": "user", "content": f"Continue this text: {prompt_text}"}
                ]

                # Generate completion
                completion = await self.server.completion(
                    prompt=self.tokenizer.apply_chat_template(messages, tokenize=False),
                    n=1,
                    max_tokens=50,
                    temperature=0.7,
                    split="eval",
                )

                completion_text = completion.choices[0].text.strip()

                # Calculate metrics
                completion_words = set(completion_text.lower().split())
                target_words = set(target_text.lower().split())

                if completion_words and target_words:
                    intersection = len(completion_words & target_words)
                    union = len(completion_words | target_words)
                    similarity = intersection / union if union > 0 else 0.0
                else:
                    similarity = 0.0

                # Quantum analysis
                quantum_coherence = self.quantum_analyzer.analyze_text(completion_text)

                # Combined evaluation score
                eval_score = (
                    self.config.perplexity_weight * similarity
                    + self.config.quantum_weight * quantum_coherence
                )

                eval_scores.append(eval_score)
                quantum_scores.append(quantum_coherence)

            except Exception as e:
                print(f"Evaluation error: {e}")
                continue

        # Log evaluation metrics
        if eval_scores:
            avg_eval_score = sum(eval_scores) / len(eval_scores)
            avg_quantum_score = sum(quantum_scores) / len(quantum_scores)

            self.eval_metrics.append(("eval/combined_score", avg_eval_score))
            self.eval_metrics.append(("eval/quantum_coherence", avg_quantum_score))
            self.eval_metrics.append(
                ("eval/perplexity_weight", self.config.perplexity_weight)
            )
            self.eval_metrics.append(
                ("eval/quantum_weight", self.config.quantum_weight)
            )

            print(
                f"Evaluation complete: avg_score={avg_eval_score:.3f}, "
                f"avg_quantum={avg_quantum_score:.3f}"
            )

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Log metrics to Weights & Biases."""
        if wandb_metrics is None:
            wandb_metrics = {}

        # Calculate and add metrics from buffer
        for metric_name, values in self.metrics_buffer.items():
            if values:
                wandb_metrics[f"train/{metric_name}_avg"] = sum(values) / len(values)
                wandb_metrics[f"train/{metric_name}_max"] = max(values)
                wandb_metrics[f"train/{metric_name}_min"] = min(values)
                wandb_metrics[f"train/{metric_name}_std"] = np.std(values)

        # Add quantum-specific metrics
        wandb_metrics["quantum/n_qubits"] = self.config.n_qubits
        wandb_metrics["quantum/n_layers"] = self.config.n_layers
        wandb_metrics["quantum/weight"] = self.config.quantum_weight
        wandb_metrics["train/perplexity_weight"] = self.config.perplexity_weight

        # Clear the buffer
        self.metrics_buffer = {key: [] for key in self.metrics_buffer}

        # Add evaluation metrics
        for name, value in self.eval_metrics:
            wandb_metrics[name] = value

        self.eval_metrics = []

        # Log to wandb using the parent method
        await super().wandb_log(wandb_metrics)
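
# Launching (illustrative; exact commands depend on your atroposlib and vLLM
# versions and are assumptions, not part of this file):
#   1. Start an OpenAI-compatible inference server, e.g. a vLLM server on
#      http://localhost:9001/v1 serving the model registered as "hermes-3-8b".
#   2. Start the Atropos rollout server on http://localhost:8000.
#   3. Run this module through the CLI entry point below, e.g.
#      `python <this_file>.py serve` (subcommand names may differ by version).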


if __name__ == "__main__":
    # Launch the environment with CLI arguments
    QuantumHybridEnv.cli()