mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
added student agent code
This commit is contained in:
parent
d17a1b0b16
commit
1d260f951c
8 changed files with 1360 additions and 86 deletions
5
.env
5
.env
|
|
@ -0,0 +1,5 @@
|
|||
# API Keys
|
||||
NOUS_API_KEY="eyJhbGciOiJIUzI1NiIsImtpZCI6IlV6SXJWd1h0dnprLVRvdzlLZWstc0M1akptWXBvX1VaVkxUZlpnMDRlOFUiLCJ0eXAiOiJKV1QifQ.eyJzdWIiOiJnaXRodWJ8MTE1OTI1MDcxIiwic2NvcGUiOiJvcGVuaWQgb2ZmbGluZV9hY2Nlc3MiLCJpc3MiOiJhcGlfa2V5X2lzc3VlciIsImF1ZCI6WyJodHRwczovL25lYml1cy1pbmZlcmVuY2UuZXUuYXV0aDAuY29tL2FwaS92Mi8iXSwiZXhwIjoxOTA1Mjg4ODM4LCJ1dWlkIjoiMDg2YWYxOGYtNGE2ZS00MWFlLTg5YzMtNGFkNWUyNjEwMzQ4IiwibmFtZSI6IkV4YW1DcmFmdCIsImV4cGlyZXNfYXQiOiIyMDMwLTA1LTE3VDIyOjUzOjU4KzAwMDAifQ.q1b45WcNS1wqprvSafOyV_FF0SknoQlASuQQNFrrsM0"
|
||||
|
||||
# Environment Settings
|
||||
LOG_LEVEL=INFO
|
||||
212
.gitignore
vendored
Normal file
212
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,212 @@
|
|||
# Environment variables
|
||||
.env
|
||||
.env.local
|
||||
.env.development.local
|
||||
.env.test.local
|
||||
.env.production.local
|
||||
|
||||
# API Keys and sensitive files
|
||||
*.key
|
||||
secrets/
|
||||
config/api_keys.json
|
||||
|
||||
# Python cache files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
Pipfile.lock
|
||||
|
||||
# poetry
|
||||
poetry.lock
|
||||
|
||||
# pdm
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
.idea/
|
||||
|
||||
# VS Code
|
||||
.vscode/
|
||||
|
||||
# macOS
|
||||
.DS_Store
|
||||
.AppleDouble
|
||||
.LSOverride
|
||||
|
||||
# Windows
|
||||
Thumbs.db
|
||||
Thumbs.db:encryptable
|
||||
ehthumbs.db
|
||||
ehthumbs_vista.db
|
||||
*.stackdump
|
||||
[Dd]esktop.ini
|
||||
|
||||
# Linux
|
||||
*~
|
||||
|
||||
# Temporary files
|
||||
*.tmp
|
||||
*.temp
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# Logs and databases
|
||||
*.log
|
||||
*.sqlite
|
||||
*.db
|
||||
|
||||
# AI/ML specific
|
||||
wandb/
|
||||
lightning_logs/
|
||||
checkpoints/
|
||||
models/
|
||||
*.pkl
|
||||
*.joblib
|
||||
*.h5
|
||||
*.hdf5
|
||||
*.ckpt
|
||||
|
||||
# Data files (uncomment if you don't want to track large datasets)
|
||||
# data/
|
||||
# *.csv
|
||||
# *.json
|
||||
# *.jsonl
|
||||
# *.gz
|
||||
# *.zip
|
||||
|
||||
# Atropos specific
|
||||
experiments/
|
||||
runs/
|
||||
outputs/
|
||||
|
||||
# Local configuration files
|
||||
local_config.json
|
||||
user_settings.json
|
||||
|
||||
# Backup files
|
||||
*.backup
|
||||
*.bak
|
||||
*.orig
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
{
|
||||
"student_id": "student001",
|
||||
"target_grade": "11th grade",
|
||||
"learning_goal": "Understand and apply key concepts in linear algebra",
|
||||
"prior_knowledge_level": "intermediate",
|
||||
"topics": [
|
||||
{
|
||||
"name": "vectors",
|
||||
"proficiency": 0.65,
|
||||
"sub_topics": ["vector addition", "scalar multiplication", "dot product", "cross product"]
|
||||
},
|
||||
{
|
||||
"name": "matrices",
|
||||
"proficiency": 0.50,
|
||||
"sub_topics": ["matrix operations", "determinants", "inverse matrices", "matrix transformations"]
|
||||
},
|
||||
{
|
||||
"name": "linear_systems",
|
||||
"proficiency": 0.40,
|
||||
"sub_topics": ["gaussian elimination", "cramer's rule", "homogeneous systems"]
|
||||
},
|
||||
{
|
||||
"name": "eigenvalues_eigenvectors",
|
||||
"proficiency": 0.30,
|
||||
"sub_topics": ["characteristic polynomial", "diagonalization", "applications"]
|
||||
}
|
||||
],
|
||||
"preferred_learning_style": "visual",
|
||||
"attention_span_minutes": 45,
|
||||
"target_milestones": [
|
||||
"Solve systems of linear equations confidently",
|
||||
"Understand eigenvalues and eigenvectors conceptually",
|
||||
"Apply matrix transformations to geometric problems"
|
||||
],
|
||||
"current_avg_score": 73
|
||||
}
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
requests>=2.28.0
|
||||
python-dotenv>=0.19.0
|
||||
gymnasium>=0.28.1
|
||||
numpy>=1.24.0
|
||||
# Atropos would be included here in a full implementation
|
||||
# atropos>=0.1.0
|
||||
150
tutor_rl_agent/agents/Student_agent.py
Normal file
150
tutor_rl_agent/agents/Student_agent.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
import os
|
||||
import openai
|
||||
|
||||
|
||||
# StudentAgent is responsible for answering teacher questions, evaluating teacher feedback,
|
||||
# tracking weak areas, and generating performance summaries. It interacts with a TeacherAgent
|
||||
# in an Atropos-compatible reinforcement learning loop.
|
||||
class StudentAgent:
|
||||
def __init__(self, profile):
|
||||
self.profile = profile
|
||||
self.weak_areas = {}
|
||||
self.log = []
|
||||
openai.api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
|
||||
def generate_answer(self, question):
|
||||
"""
|
||||
Generates a student-like answer to the teacher's question using the OpenAI LLM.
|
||||
"""
|
||||
messages = [
|
||||
{"role": "system", "content": f"You are a student learning {self.profile['subject']} at {self.profile['difficulty']} level. Your goal is: {self.profile['goal']}."},
|
||||
{"role": "user", "content": f"Teacher asks: {question}"}
|
||||
]
|
||||
|
||||
|
||||
response = openai.ChatCompletion.create(
|
||||
model="gpt-4",
|
||||
messages=messages
|
||||
)
|
||||
answer = response.choices[0].message.content.strip()
|
||||
self.log.append({"question": question, "student_answer": answer})
|
||||
return answer
|
||||
|
||||
|
||||
def evaluate_teacher_effectiveness(self, question, explanation, student_answer):
|
||||
"""
|
||||
Evaluates how effective the teacher's explanation was based on the student's response and learning goal.
|
||||
Returns a score between 0.0 and 1.0.
|
||||
"""
|
||||
messages = [
|
||||
{"role": "user", "content": f"""
|
||||
Teacher asked: {question}
|
||||
Teacher explained: {explanation}
|
||||
Student answered: {student_answer}
|
||||
Student's learning goal: {self.profile['goal']}
|
||||
|
||||
|
||||
Rate the teacher's effectiveness in helping the student learn the concept.
|
||||
Respond with a number between 0.0 (not helpful) and 1.0 (very effective). Only output the number.
|
||||
"""}
|
||||
]
|
||||
|
||||
|
||||
response = openai.ChatCompletion.create(
|
||||
model="gpt-4",
|
||||
messages=messages
|
||||
)
|
||||
return float(response.choices[0].message.content.strip())
|
||||
|
||||
|
||||
def update_weak_areas(self, question, student_answer):
|
||||
"""
|
||||
Tracks weak topic areas by comparing the question and student's answer.
|
||||
"""
|
||||
for topic in ["loops", "recursion", "lists", "variables", "functions"]:
|
||||
if topic in question.lower() and topic not in student_answer.lower():
|
||||
self.weak_areas[topic] = self.weak_areas.get(topic, 0) + 1
|
||||
return self.weak_areas
|
||||
def compare_answers(self, original, revised):
|
||||
"""
|
||||
Compares original and revised answers. Returns 'improved' or 'no change'.
|
||||
"""
|
||||
if len(revised) > len(original) and revised != original:
|
||||
return "improved"
|
||||
return "no change"
|
||||
|
||||
|
||||
def summarize_performance(self):
|
||||
"""
|
||||
Summarizes teacher performance based on all interactions.
|
||||
Returns average score, total questions, and weak areas.
|
||||
"""
|
||||
total = len(self.log)
|
||||
scores = [entry.get("score", 0.0) for entry in self.log if "score" in entry]
|
||||
avg_score = sum(scores) / len(scores) if scores else 0.0
|
||||
return {
|
||||
"total_questions": total,
|
||||
"avg_teacher_score": round(avg_score, 2),
|
||||
"weak_areas": self.weak_areas
|
||||
}
|
||||
def revise_answer_based_on_explanation(self, question, explanation):
|
||||
"""
|
||||
Generates a revised answer based on the teacher's explanation.
|
||||
"""
|
||||
messages = [
|
||||
{"role": "system", "content": f"You are a student learning {self.profile['subject']} at {self.profile['difficulty']} level. Your goal is: {self.profile['goal']}."},
|
||||
{"role": "user", "content": f"""
|
||||
Teacher asked: {question}
|
||||
Teacher explained: {explanation}
|
||||
|
||||
|
||||
Using the teacher's explanation, try to answer the original question again.
|
||||
"""}
|
||||
]
|
||||
response = openai.ChatCompletion.create(
|
||||
model="gpt-4",
|
||||
messages=messages
|
||||
)
|
||||
revised = response.choices[0].message.content.strip()
|
||||
self.log.append({"question": question, "revised_answer": revised})
|
||||
return revised
|
||||
|
||||
|
||||
def reset_log(self):
|
||||
"""
|
||||
Resets the internal log and weak areas for a new episode.
|
||||
"""
|
||||
self.log = []
|
||||
self.weak_areas = {}
|
||||
|
||||
|
||||
def log_score(self, score):
|
||||
"""
|
||||
Logs the teacher score for the most recent interaction.
|
||||
"""
|
||||
if self.log:
|
||||
self.log[-1]["score"] = score
|
||||
|
||||
|
||||
def get_last_answer(self):
|
||||
"""
|
||||
Returns the most recent student answer, if available.
|
||||
"""
|
||||
if self.log and "student_answer" in self.log[-1]:
|
||||
return self.log[-1]["student_answer"]
|
||||
return None
|
||||
|
||||
|
||||
def track_revision_success(self, question, revised_answer):
|
||||
"""
|
||||
Checks if revised answer addressed previously missed weak topics.
|
||||
Removes resolved topics from weak_areas.
|
||||
"""
|
||||
resolved_topics = []
|
||||
for topic in list(self.weak_areas.keys()):
|
||||
if topic in question.lower() and topic in revised_answer.lower():
|
||||
resolved_topics.append(topic)
|
||||
del self.weak_areas[topic]
|
||||
return resolved_topics
|
||||
|
||||
|
|
@ -1,89 +1,445 @@
|
|||
import openai
|
||||
import random
|
||||
import os
|
||||
import json
|
||||
import requests
|
||||
from typing import Dict, List, Any, Tuple, Optional
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
class TeacherAgent:
|
||||
def __init__(self, model="gpt-4"):
|
||||
self.model = model
|
||||
self.student_profile = {}
|
||||
self.weak_areas = []
|
||||
self.question_history = []
|
||||
"""
|
||||
TeacherAgent for the LLM-Based Interactive Teacher-Student Tutor Environment.
|
||||
|
||||
This agent is responsible for:
|
||||
1. Generating appropriate questions based on student profile
|
||||
2. Evaluating student responses
|
||||
3. Providing explanations for incorrect answers
|
||||
4. Adapting teaching strategy based on student performance
|
||||
"""
|
||||
|
||||
def __init__(self, profile_path: str, api_key: Optional[str] = None):
|
||||
"""
|
||||
Initialize the TeacherAgent with a student profile.
|
||||
|
||||
Args:
|
||||
profile_path: Path to the JSON file containing student profile
|
||||
api_key: Optional API key for LLM (defaults to environment variable)
|
||||
"""
|
||||
# Load student profile
|
||||
self.profile = self._load_profile(profile_path)
|
||||
|
||||
# Set up LLM client
|
||||
self.api_key = api_key or os.getenv("NOUS_API_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("No API key provided. Set NOUS_API_KEY in .env file or pass as argument.")
|
||||
|
||||
# Nous API endpoint
|
||||
self.api_endpoint = "https://api.nousresearch.com/v1/chat/completions"
|
||||
self.history = []
|
||||
|
||||
# Track student performance metrics
|
||||
self.student_metrics = {
|
||||
"questions_asked": 0,
|
||||
"correct_answers": 0,
|
||||
"topic_performance": {},
|
||||
"difficulty_distribution": {"easy": 0, "medium": 0, "hard": 0}
|
||||
}
|
||||
|
||||
# Initialize the topics from profile
|
||||
for topic in self.profile.get("topics", []):
|
||||
self.student_metrics["topic_performance"][topic["name"]] = {
|
||||
"questions": 0,
|
||||
"correct": 0,
|
||||
"accuracy": 0.0
|
||||
}
|
||||
|
||||
def _load_profile(self, profile_path: str) -> Dict[str, Any]:
|
||||
"""Load student profile from JSON file."""
|
||||
try:
|
||||
with open(profile_path, 'r') as file:
|
||||
profile = json.load(file)
|
||||
return profile
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load profile from {profile_path}: {e}")
|
||||
|
||||
def _call_llm(self, prompt: str) -> str:
|
||||
"""Make a call to the Nous Research API."""
|
||||
try:
|
||||
response = requests.post(
|
||||
self.api_endpoint,
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
json={
|
||||
"model": "hermes-3-405b-instruct", # Using Hermes-3-405B model
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are an expert teacher assistant."},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
"max_tokens": 1000,
|
||||
"temperature": 0.7
|
||||
},
|
||||
timeout=30
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
return result["choices"][0]["message"]["content"]
|
||||
else:
|
||||
print(f"API Error: {response.status_code} - {response.text}")
|
||||
return "I couldn't generate a response at this time."
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"Network error calling Nous API: {e}")
|
||||
return "I couldn't generate a response due to a network error."
|
||||
except Exception as e:
|
||||
print(f"Error calling Nous API: {e}")
|
||||
return "I couldn't generate a response at this time."
|
||||
|
||||
def generate_question(self, topic: Optional[str] = None,
|
||||
difficulty: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a multiple-choice question based on student profile and history.
|
||||
|
||||
Args:
|
||||
topic: Optional topic override
|
||||
difficulty: Optional difficulty override ('easy', 'medium', 'hard')
|
||||
|
||||
Returns:
|
||||
Dict containing question, options, correct answer, and explanation
|
||||
"""
|
||||
# Use provided topic/difficulty or select based on student performance
|
||||
selected_topic = topic or self._select_topic()
|
||||
selected_difficulty = difficulty or self._select_difficulty(selected_topic)
|
||||
|
||||
# Craft prompt for LLM to generate a question
|
||||
prompt = self._craft_question_prompt(selected_topic, selected_difficulty)
|
||||
|
||||
# Get response from LLM
|
||||
response = self._call_llm(prompt)
|
||||
|
||||
# Parse the response to extract question components
|
||||
try:
|
||||
question_data = self._parse_question_response(response)
|
||||
|
||||
# Update metrics
|
||||
self.student_metrics["questions_asked"] += 1
|
||||
self.student_metrics["difficulty_distribution"][selected_difficulty] += 1
|
||||
self.student_metrics["topic_performance"][selected_topic]["questions"] += 1
|
||||
|
||||
# Add metadata
|
||||
question_data["topic"] = selected_topic
|
||||
question_data["difficulty"] = selected_difficulty
|
||||
|
||||
# Add to history
|
||||
self.history.append({
|
||||
"type": "question",
|
||||
"data": question_data
|
||||
})
|
||||
|
||||
return question_data
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to parse question response: {e}")
|
||||
return {
|
||||
"error": "Failed to generate valid question",
|
||||
"raw_response": response
|
||||
}
|
||||
|
||||
def _craft_question_prompt(self, topic: str, difficulty: str) -> str:
|
||||
"""Craft a prompt for the LLM to generate a multiple-choice question."""
|
||||
grade_level = self.profile.get("target_grade", "high school")
|
||||
|
||||
prompt = f"""
|
||||
Please create a {difficulty} level multiple-choice question about {topic} appropriate for a {grade_level} student.
|
||||
|
||||
The question should have:
|
||||
1. A clear, concise question statement
|
||||
2. Four possible answer options (A, B, C, D)
|
||||
3. The correct answer (just the letter)
|
||||
4. A detailed explanation of why the correct answer is right and why the others are wrong
|
||||
|
||||
Format your response as a JSON object with the following structure:
|
||||
{{
|
||||
"question": "...",
|
||||
"options": {{
|
||||
"A": "...",
|
||||
"B": "...",
|
||||
"C": "...",
|
||||
"D": "..."
|
||||
}},
|
||||
"correct_answer": "A/B/C/D",
|
||||
"explanation": "..."
|
||||
}}
|
||||
"""
|
||||
return prompt
|
||||
|
||||
def _parse_question_response(self, response: str) -> Dict[str, Any]:
|
||||
"""Parse the LLM response to extract question data."""
|
||||
# Try to find JSON content within the response
|
||||
try:
|
||||
# Extract just the JSON part if there's additional text
|
||||
start = response.find('{')
|
||||
end = response.rfind('}') + 1
|
||||
if start >= 0 and end > start:
|
||||
json_str = response[start:end]
|
||||
return json.loads(json_str)
|
||||
else:
|
||||
return json.loads(response) # Try parsing the whole response
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Could not parse JSON from LLM response")
|
||||
|
||||
def _select_topic(self) -> str:
|
||||
"""
|
||||
Select a topic to focus on based on student performance.
|
||||
Prioritizes topics where student is struggling.
|
||||
"""
|
||||
if not self.history:
|
||||
# For first question, select a random topic from profile
|
||||
return self.profile["topics"][0]["name"]
|
||||
|
||||
# Find topics with lowest accuracy
|
||||
topic_accuracies = {}
|
||||
for topic, data in self.student_metrics["topic_performance"].items():
|
||||
if data["questions"] > 0:
|
||||
accuracy = data["correct"] / data["questions"]
|
||||
topic_accuracies[topic] = accuracy
|
||||
else:
|
||||
# Prioritize untested topics
|
||||
topic_accuracies[topic] = 0.0
|
||||
|
||||
# Get topic with lowest accuracy (or random if tied)
|
||||
import random
|
||||
min_accuracy = min(topic_accuracies.values())
|
||||
weakest_topics = [t for t, a in topic_accuracies.items() if a == min_accuracy]
|
||||
return random.choice(weakest_topics)
|
||||
|
||||
def _select_difficulty(self, topic: str) -> str:
|
||||
"""
|
||||
Select appropriate difficulty based on student performance in the topic.
|
||||
"""
|
||||
topic_data = self.student_metrics["topic_performance"].get(topic, {})
|
||||
|
||||
# If no questions asked yet, start with medium difficulty
|
||||
if topic_data.get("questions", 0) == 0:
|
||||
return "medium"
|
||||
|
||||
# Calculate accuracy for the topic
|
||||
accuracy = topic_data.get("correct", 0) / topic_data.get("questions", 1)
|
||||
|
||||
# Adjust difficulty based on accuracy
|
||||
if accuracy < 0.4:
|
||||
return "easy" # Student is struggling, make it easier
|
||||
elif accuracy > 0.8:
|
||||
return "hard" # Student is doing well, make it harder
|
||||
else:
|
||||
return "medium" # Keep at medium difficulty
|
||||
|
||||
def evaluate_response(self, question_id: int, student_answer: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Evaluate student's answer to a question.
|
||||
|
||||
Args:
|
||||
question_id: Index of the question in history
|
||||
student_answer: Student's selected answer (A/B/C/D)
|
||||
|
||||
Returns:
|
||||
Dictionary with evaluation results
|
||||
"""
|
||||
# Retrieve the question from history
|
||||
if question_id >= len(self.history) or self.history[question_id]["type"] != "question":
|
||||
return {"error": "Invalid question ID"}
|
||||
|
||||
question_data = self.history[question_id]["data"]
|
||||
|
||||
# Check if answer is correct
|
||||
is_correct = student_answer.upper() == question_data["correct_answer"].upper()
|
||||
|
||||
# Update metrics
|
||||
if is_correct:
|
||||
self.student_metrics["correct_answers"] += 1
|
||||
topic = question_data["topic"]
|
||||
self.student_metrics["topic_performance"][topic]["correct"] += 1
|
||||
|
||||
# Update accuracy
|
||||
questions = self.student_metrics["topic_performance"][topic]["questions"]
|
||||
correct = self.student_metrics["topic_performance"][topic]["correct"]
|
||||
self.student_metrics["topic_performance"][topic]["accuracy"] = correct / questions
|
||||
|
||||
# Prepare evaluation response
|
||||
evaluation = {
|
||||
"is_correct": is_correct,
|
||||
"correct_answer": question_data["correct_answer"],
|
||||
"explanation": question_data["explanation"]
|
||||
}
|
||||
|
||||
# If incorrect, generate a tailored explanation
|
||||
if not is_correct:
|
||||
selected_option = student_answer.upper()
|
||||
if selected_option in question_data["options"]:
|
||||
prompt = self._craft_explanation_prompt(
|
||||
question_data["question"],
|
||||
question_data["options"],
|
||||
question_data["correct_answer"],
|
||||
selected_option
|
||||
)
|
||||
tailored_explanation = self._call_llm(prompt)
|
||||
evaluation["tailored_explanation"] = tailored_explanation
|
||||
|
||||
# Add to history
|
||||
self.history.append({
|
||||
"type": "evaluation",
|
||||
"data": {
|
||||
"question_id": question_id,
|
||||
"student_answer": student_answer,
|
||||
"evaluation": evaluation
|
||||
}
|
||||
})
|
||||
|
||||
return evaluation
|
||||
|
||||
def _craft_explanation_prompt(self, question: str, options: Dict[str, str],
|
||||
correct_answer: str, selected_answer: str) -> str:
|
||||
"""Craft a prompt for the LLM to generate a tailored explanation."""
|
||||
prompt = f"""
|
||||
A student answered the following multiple-choice question incorrectly:
|
||||
|
||||
Question: {question}
|
||||
|
||||
Options:
|
||||
A: {options.get('A', 'N/A')}
|
||||
B: {options.get('B', 'N/A')}
|
||||
C: {options.get('C', 'N/A')}
|
||||
D: {options.get('D', 'N/A')}
|
||||
|
||||
The correct answer is {correct_answer}: {options.get(correct_answer, 'N/A')}
|
||||
|
||||
The student selected {selected_answer}: {options.get(selected_answer, 'N/A')}
|
||||
|
||||
Please provide a detailed, supportive explanation that:
|
||||
1. Explains why their answer is incorrect
|
||||
2. Identifies the misconception that might have led them to this answer
|
||||
3. Clearly explains why the correct answer is right
|
||||
4. Provides an additional example or analogy to reinforce the concept
|
||||
"""
|
||||
return prompt
|
||||
|
||||
def get_performance_summary(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get a summary of student performance across topics and difficulties.
|
||||
|
||||
Returns:
|
||||
Dictionary with performance metrics
|
||||
"""
|
||||
total_questions = self.student_metrics["questions_asked"]
|
||||
total_correct = self.student_metrics["correct_answers"]
|
||||
|
||||
summary = {
|
||||
"total_questions": total_questions,
|
||||
"total_correct": total_correct,
|
||||
"overall_accuracy": total_correct / total_questions if total_questions > 0 else 0,
|
||||
"topic_performance": self.student_metrics["topic_performance"],
|
||||
"difficulty_distribution": self.student_metrics["difficulty_distribution"]
|
||||
}
|
||||
|
||||
return summary
|
||||
|
||||
def generate_adaptive_lesson_plan(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate an adaptive lesson plan based on student performance.
|
||||
|
||||
Returns:
|
||||
Dictionary with recommended topics and strategies
|
||||
"""
|
||||
# Get performance summary
|
||||
performance = self.get_performance_summary()
|
||||
|
||||
# Identify weakest topics (below 60% accuracy)
|
||||
weak_topics = []
|
||||
for topic, data in performance["topic_performance"].items():
|
||||
if data["questions"] > 0 and data["accuracy"] < 0.6:
|
||||
weak_topics.append({
|
||||
"topic": topic,
|
||||
"accuracy": data["accuracy"],
|
||||
"questions": data["questions"]
|
||||
})
|
||||
|
||||
# Sort weak topics by accuracy (ascending)
|
||||
weak_topics.sort(key=lambda x: x["accuracy"])
|
||||
|
||||
# Craft prompt for lesson plan
|
||||
prompt = f"""
|
||||
Based on a student's performance data, generate an adaptive lesson plan.
|
||||
|
||||
Overall accuracy: {performance["overall_accuracy"]:.2f}
|
||||
|
||||
Topic performance:
|
||||
{json.dumps(performance["topic_performance"], indent=2)}
|
||||
|
||||
Please create a focused lesson plan that:
|
||||
1. Prioritizes the weakest topics (if any)
|
||||
2. Recommends specific learning activities for each weak topic
|
||||
3. Suggests a balanced approach to reinforce strong topics while improving weak ones
|
||||
4. Includes 2-3 specific example questions/exercises for the weakest topic
|
||||
|
||||
Format your response as a JSON object with the following structure:
|
||||
{{
|
||||
"prioritized_topics": ["topic1", "topic2", ...],
|
||||
"recommended_activities": {{
|
||||
"topic1": ["activity1", "activity2", ...],
|
||||
...
|
||||
}},
|
||||
"example_questions": [
|
||||
{{
|
||||
"topic": "topic1",
|
||||
"question": "...",
|
||||
"answer": "..."
|
||||
}},
|
||||
...
|
||||
],
|
||||
"overall_strategy": "..."
|
||||
}}
|
||||
"""
|
||||
|
||||
# Get response from LLM
|
||||
response = self._call_llm(prompt)
|
||||
|
||||
# Parse the response
|
||||
try:
|
||||
lesson_plan = self._parse_question_response(response)
|
||||
return lesson_plan
|
||||
except Exception as e:
|
||||
print(f"Failed to parse lesson plan response: {e}")
|
||||
return {
|
||||
"error": "Failed to generate valid lesson plan",
|
||||
"raw_response": response
|
||||
}
|
||||
|
||||
def load_student_profile(self, profile_dict):
|
||||
self.student_profile = profile_dict
|
||||
self.weak_areas = []
|
||||
self.question_history = []
|
||||
|
||||
def _build_prompt(self, topic, difficulty, goal):
|
||||
return (
|
||||
f"You are an expert teacher helping a student learn about '{topic}' at a '{difficulty}' level.\n"
|
||||
f"The student is learning for this goal: '{goal}'.\n"
|
||||
"Start with a clear, engaging multiple-choice question to assess their baseline understanding."
|
||||
)
|
||||
|
||||
def generate_initial_question(self):
|
||||
topic = self.student_profile.get("topic")
|
||||
difficulty = self.student_profile.get("difficulty")
|
||||
goal = self.student_profile.get("goal")
|
||||
|
||||
prompt = self._build_prompt(topic, difficulty, goal)
|
||||
|
||||
response = openai.ChatCompletion.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful, adaptive teacher."},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
)
|
||||
|
||||
question = response['choices'][0]['message']['content']
|
||||
self.question_history.append({"question": question, "type": "initial"})
|
||||
return question
|
||||
|
||||
def evaluate_student_response(self, question, student_answer):
|
||||
prompt = (
|
||||
f"You asked: {question}\n"
|
||||
f"The student answered: {student_answer}\n"
|
||||
"Evaluate the correctness. If incorrect, explain why in simple terms and identify the specific weak area."
|
||||
)
|
||||
|
||||
response = openai.ChatCompletion.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are an expert teacher scoring a student response."},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
)
|
||||
|
||||
feedback = response['choices'][0]['message']['content']
|
||||
|
||||
# Optional: extract weak area via LLM or regex (simplified here)
|
||||
if "incorrect" in feedback.lower() or "not correct" in feedback.lower():
|
||||
self.weak_areas.append("unidentified weak topic")
|
||||
|
||||
self.question_history[-1]["student_answer"] = student_answer
|
||||
self.question_history[-1]["teacher_feedback"] = feedback
|
||||
|
||||
return feedback
|
||||
|
||||
def follow_up_on_weakness(self):
|
||||
if not self.weak_areas:
|
||||
return "Student has no identified weaknesses so far."
|
||||
|
||||
weak_topic = self.weak_areas[-1]
|
||||
prompt = (
|
||||
f"You are helping a student understand '{self.student_profile['topic']}' better.\n"
|
||||
f"They are struggling with: '{weak_topic}'.\n"
|
||||
"Explain this weak area briefly and follow up with another multiple-choice question."
|
||||
)
|
||||
|
||||
response = openai.ChatCompletion.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a tutor addressing student knowledge gaps."},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
)
|
||||
|
||||
question = response['choices'][0]['message']['content']
|
||||
self.question_history.append({"question": question, "type": "followup", "weak_topic": weak_topic})
|
||||
return question
|
||||
# Example usage:
|
||||
if __name__ == "__main__":
|
||||
# Path to student profile
|
||||
profile_path = "config/example_profile.json"
|
||||
|
||||
# Initialize teacher agent
|
||||
teacher = TeacherAgent(profile_path)
|
||||
|
||||
# Generate a question
|
||||
question = teacher.generate_question()
|
||||
print("Generated question:", question)
|
||||
|
||||
# Simulate student response (typically this would come from StudentAgent)
|
||||
student_answer = "A" # Placeholder
|
||||
|
||||
# Evaluate response
|
||||
evaluation = teacher.evaluate_response(0, student_answer)
|
||||
print("Evaluation:", evaluation)
|
||||
|
||||
# Get performance summary
|
||||
summary = teacher.get_performance_summary()
|
||||
print("Performance summary:", summary)
|
||||
|
||||
# Generate adaptive lesson plan
|
||||
lesson_plan = teacher.generate_adaptive_lesson_plan()
|
||||
print("Adaptive lesson plan:", lesson_plan)
|
||||
|
|
@ -0,0 +1,393 @@
|
|||
import os
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
import numpy as np
|
||||
|
||||
# Import the TeacherAgent class
|
||||
from teacher_agent import TeacherAgent
|
||||
|
||||
class TutorEnv(gym.Env):
    """
    TutorEnv for the LLM-Based Interactive Teacher-Student Tutor Environment.

    This environment follows the Atropos LanguageEnv pattern and is responsible for:
    1. Managing the interaction between TeacherAgent and StudentAgent
    2. Computing rewards based on student learning
    3. Tracking the state of the tutoring session
    """

    metadata = {"render_modes": ["human"]}

    def __init__(self, profile_path: str, render_mode: Optional[str] = None):
        """
        Initialize the TutorEnv with a student profile.

        Args:
            profile_path: Path to the JSON file containing the student profile
            render_mode: Optional rendering mode ("human" prints transcripts)
        """
        # Load student profile
        with open(profile_path, 'r') as file:
            self.profile = json.load(file)

        # TeacherAgent drives LLM-backed question generation and evaluation.
        self.teacher_agent = TeacherAgent(profile_path)

        # Initialize student metrics from the profile
        self.init_student_metrics()

        # Discrete action space:
        #   0: Ask easy question
        #   1: Ask medium question
        #   2: Ask hard question
        #   3: Generate adaptive lesson plan (terminates the episode)
        self.action_space = spaces.Discrete(4)

        # Observation space:
        #   student_performance: [overall_accuracy, vectors_acc, matrices_acc, linear_systems_acc]
        #   question_history: outcomes of the last 10 questions (1=correct, 0=incorrect)
        self.observation_space = spaces.Dict({
            'student_performance': spaces.Box(
                low=np.array([0, 0, 0, 0]),
                high=np.array([1, 1, 1, 1]),
                dtype=np.float32
            ),
            'question_history': spaces.MultiBinary(10)
        })

        # Conversation state mirrored into observations
        self.state = {
            'student_performance': np.zeros(4, dtype=np.float32),
            'question_history': np.zeros(10, dtype=np.int8)
        }

        # Bookkeeping for reward calculation and rendering
        self.current_question = None
        self.question_count = 0
        self.correct_count = 0
        self.episode_reward = 0.0
        self.last_action = None
        self.history = []

        # Set render mode
        self.render_mode = render_mode

    def init_student_metrics(self):
        """Initialize student metrics (overall and per-topic accuracy) from the profile."""
        self.student_metrics = {
            # Profile stores a 0-100 score; normalize to [0, 1].
            "overall_accuracy": self.profile.get("current_avg_score", 0) / 100,
            "topic_accuracies": {}
        }

        # Seed per-topic proficiency values from the profile (default 0.5).
        for topic in self.profile.get("topics", []):
            topic_name = topic.get("name")
            proficiency = topic.get("proficiency", 0.5)
            self.student_metrics["topic_accuracies"][topic_name] = proficiency

    def reset(self, seed=None, options=None):
        """Reset the environment to its initial state and return (observation, info)."""
        super().reset(seed=seed)

        # Reset metrics and per-episode bookkeeping
        self.init_student_metrics()
        self.question_count = 0
        self.correct_count = 0
        self.episode_reward = 0.0
        self.last_action = None
        self.history = []
        # BUG FIX: also clear the pending question so a stale question from a
        # previous episode cannot leak into the new one.
        self.current_question = None

        # Reset observable state
        self.state = {
            'student_performance': np.array([
                self.student_metrics["overall_accuracy"],
                self.student_metrics["topic_accuracies"].get("vectors", 0.5),
                self.student_metrics["topic_accuracies"].get("matrices", 0.5),
                self.student_metrics["topic_accuracies"].get("linear_systems", 0.5)
            ], dtype=np.float32),
            'question_history': np.zeros(10, dtype=np.int8)
        }

        return self.state, {}

    def step(self, action):
        """
        Take a step in the environment based on the action.

        Args:
            action: Integer action from the action space

        Returns:
            Tuple of (next_state, reward, done, truncated, info)
        """
        self.last_action = action

        if action == 3:  # Generate adaptive lesson plan
            lesson_plan = self.teacher_agent.generate_adaptive_lesson_plan()
            reward = self._compute_lesson_plan_reward(lesson_plan)
            info = {"lesson_plan": lesson_plan}
            # BUG FIX: record this step in history. render() reads the lesson
            # plan from self.history[-1]; previously only question steps were
            # appended, so rendering action 3 always displayed an empty plan.
            self.history.append({
                "action": action,
                "lesson_plan": lesson_plan,
                "reward": reward
            })
            done = True  # End episode after generating lesson plan
        else:
            # Map action to question difficulty
            difficulty_map = {0: "easy", 1: "medium", 2: "hard"}
            difficulty = difficulty_map[action]

            # Generate question via the teacher
            question = self.teacher_agent.generate_question(difficulty=difficulty)
            self.current_question = question

            # Simulate student answer (in a real implementation, this would
            # come from StudentAgent)
            student_answer, is_correct = self._simulate_student_answer(question)

            # Evaluate response.
            # NOTE(review): assumes the question sits two entries back in
            # self.teacher_agent.history — confirm against teacher_agent.py,
            # which appears to append to an attribute named question_history.
            evaluation = self.teacher_agent.evaluate_response(
                len(self.teacher_agent.history) - 2,
                student_answer
            )

            # Update counters
            self.question_count += 1
            if is_correct:
                self.correct_count += 1

            # Update student metrics based on the result
            self._update_student_metrics(question["topic"], is_correct)

            # Compute reward
            reward = self._compute_question_reward(question, is_correct)

            # Update observable state
            self._update_state(is_correct)

            # Add to history
            self.history.append({
                "action": action,
                "question": question,
                "student_answer": student_answer,
                "is_correct": is_correct,
                "reward": reward
            })

            info = {
                "question": question,
                "student_answer": student_answer,
                "evaluation": evaluation,
                "is_correct": is_correct
            }

            # Episode is done after 10 questions
            done = self.question_count >= 10

        # Accumulate episode reward
        self.episode_reward += reward

        truncated = False
        return self.state, reward, done, truncated, info

    def _simulate_student_answer(self, question: Dict[str, Any]) -> Tuple[str, bool]:
        """
        Simulate a student answering the question based on their profile.

        In a real implementation, this would be replaced with StudentAgent.

        Returns:
            Tuple of (answer_key, is_correct).
        """
        topic = question.get("topic")
        difficulty = question.get("difficulty")

        # Student proficiency for this topic (default 0.5 for unseen topics)
        proficiency = self.student_metrics["topic_accuracies"].get(topic, 0.5)

        # Adjust probability of a correct answer based on difficulty
        difficulty_factor = {
            "easy": 0.3,    # +30% chance of getting it right
            "medium": 0.0,  # no adjustment
            "hard": -0.2    # -20% chance of getting it right
        }

        # Probability of answering correctly, clamped away from certainty
        correct_prob = proficiency + difficulty_factor.get(difficulty, 0.0)
        correct_prob = max(0.1, min(0.9, correct_prob))

        # Local import kept to match the module's top-level import set.
        import random
        is_correct = random.random() < correct_prob

        if is_correct:
            return question["correct_answer"], True

        # Otherwise return a random incorrect option
        options = list(question["options"].keys())
        options.remove(question["correct_answer"])
        return random.choice(options), False

    def _update_student_metrics(self, topic: str, is_correct: bool):
        """Update overall and per-topic accuracy after a question result."""
        # Overall accuracy across the questions asked so far this episode
        self.student_metrics["overall_accuracy"] = self.correct_count / self.question_count

        # Per-topic accuracy with a small learning effect
        current_accuracy = self.student_metrics["topic_accuracies"].get(topic, 0.5)

        if is_correct:
            # Small improvement when answering correctly
            new_accuracy = current_accuracy + 0.02
        else:
            # Smaller gain on a miss. (The original comment claimed a "larger
            # improvement" here, but the code has always added less than for
            # a correct answer; the comment was wrong, not the code.)
            new_accuracy = current_accuracy + 0.01

        # Clamp accuracy between 0 and 1
        self.student_metrics["topic_accuracies"][topic] = max(0.0, min(1.0, new_accuracy))

    def _update_state(self, is_correct: bool):
        """Refresh the observation dict after a question result."""
        self.state['student_performance'] = np.array([
            self.student_metrics["overall_accuracy"],
            self.student_metrics["topic_accuracies"].get("vectors", 0.5),
            self.student_metrics["topic_accuracies"].get("matrices", 0.5),
            self.student_metrics["topic_accuracies"].get("linear_systems", 0.5)
        ], dtype=np.float32)

        # Shift the rolling window left and record the newest outcome
        self.state['question_history'] = np.roll(self.state['question_history'], -1)
        self.state['question_history'][-1] = 1 if is_correct else 0

    def _compute_question_reward(self, question: Dict[str, Any], is_correct: bool) -> float:
        """
        Compute reward for asking a question.

        Reward factors:
        1. Base reward for a correct answer (penalty for incorrect)
        2. Bonus for appropriate difficulty (challenging but achievable)
        3. Bonus for targeting weak topics
        """
        # Base reward
        reward = 1.0 if is_correct else -0.5

        topic = question.get("topic")
        difficulty = question.get("difficulty")

        # Student proficiency for this topic
        proficiency = self.student_metrics["topic_accuracies"].get(topic, 0.5)

        # Map difficulty labels onto the proficiency scale
        difficulty_values = {"easy": 0.3, "medium": 0.6, "hard": 0.9}
        difficulty_value = difficulty_values.get(difficulty, 0.6)

        # Bonus is highest when difficulty matches proficiency
        difficulty_match = 1.0 - abs(difficulty_value - proficiency) * 2
        difficulty_bonus = max(0.0, difficulty_match) * 0.5

        # Bonus for targeting weak topics (inverse of proficiency)
        weakness_bonus = (1.0 - proficiency) * 0.5

        return reward + difficulty_bonus + weakness_bonus

    def _compute_lesson_plan_reward(self, lesson_plan: Dict[str, Any]) -> float:
        """
        Compute reward for generating an adaptive lesson plan.

        Reward factors:
        1. Base reward for generating a plan (-1.0 if the plan carries an error)
        2. Quality of topic prioritization (focus on weak areas)
        3. Diversity of recommended activities and example questions
        """
        # Failed generation is penalized outright
        if "error" in lesson_plan:
            return -1.0

        # Base reward
        reward = 1.0

        prioritized_topics = lesson_plan.get("prioritized_topics", [])

        # Identify the two weakest topics by current accuracy
        topic_accuracies = sorted(
            self.student_metrics["topic_accuracies"].items(),
            key=lambda item: item[1]
        )
        weakest_topics = [topic for topic, _ in topic_accuracies[:2]]

        # Bonus for each of the weakest topics the plan prioritizes
        weak_topic_coverage = sum(1 for topic in prioritized_topics if topic in weakest_topics)
        weak_topic_bonus = weak_topic_coverage * 0.5

        # Bonus for activity diversity (saturates at 5 activities)
        activity_count = sum(len(activities) for activities in
                             lesson_plan.get("recommended_activities", {}).values())
        activity_bonus = min(activity_count / 5, 1.0) * 0.5

        # Bonus for example questions (saturates at 3)
        example_count = len(lesson_plan.get("example_questions", []))
        example_bonus = min(example_count / 3, 1.0) * 0.5

        return reward + weak_topic_bonus + activity_bonus + example_bonus

    def render(self):
        """Print a human-readable transcript of the last step (human mode only)."""
        if self.render_mode != "human":
            return

        if self.last_action is None:
            print("\n=== TutorEnv: New session started ===")
            return

        if self.last_action == 3:  # Lesson plan
            last_info = self.history[-1] if self.history else {}
            lesson_plan = last_info.get("lesson_plan", {})

            print("\n=== Generated Adaptive Lesson Plan ===")
            print(f"Prioritized topics: {', '.join(lesson_plan.get('prioritized_topics', []))}")
            print("Recommended activities:")
            for topic, activities in lesson_plan.get("recommended_activities", {}).items():
                print(f"  - {topic}: {', '.join(activities)}")
            print(f"Overall strategy: {lesson_plan.get('overall_strategy', 'N/A')}")
        else:
            # Render the last question/answer exchange
            last_item = self.history[-1] if self.history else {}
            question = last_item.get("question", {})
            student_answer = last_item.get("student_answer", "")
            is_correct = last_item.get("is_correct", False)
            reward = last_item.get("reward", 0.0)

            print("\n=== Question and Answer ===")
            print(f"Topic: {question.get('topic', 'N/A')}")
            print(f"Difficulty: {question.get('difficulty', 'N/A')}")
            print(f"Question: {question.get('question', 'N/A')}")
            print(f"Options:")
            for key, value in question.get("options", {}).items():
                print(f"  {key}: {value}")
            print(f"Student answer: {student_answer}")
            print(f"Correct answer: {question.get('correct_answer', 'N/A')}")
            print(f"Result: {'Correct' if is_correct else 'Incorrect'}")
            print(f"Reward: {reward:.2f}")

        # Current session metrics
        print("\n=== Current Metrics ===")
        print(f"Questions asked: {self.question_count}")
        print(f"Correct answers: {self.correct_count}")
        print(f"Overall accuracy: {self.student_metrics['overall_accuracy']:.2f}")
        print("Topic accuracies:")
        for topic, accuracy in self.student_metrics["topic_accuracies"].items():
            print(f"  - {topic}: {accuracy:.2f}")
        print(f"Episode reward so far: {self.episode_reward:.2f}")

    def close(self):
        """Clean up resources (nothing to release for this environment)."""
        pass
||||
116
tutor_rl_agent/runner/run_loop.py
Normal file
116
tutor_rl_agent/runner/run_loop.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
import os
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Import our TutorEnv
|
||||
from tutor_env import TutorEnv
|
||||
|
||||
# This would be replaced with Atropos imports in a full implementation
|
||||
# For now, we'll simulate the RL loop
|
||||
class SimpleAgent:
    """
    A simple epsilon-greedy agent that selects actions for the TutorEnv.

    This is a placeholder for an actual Atropos policy: it tracks one value
    estimate per action (a bandit — the observation is ignored).
    """

    def __init__(self, action_space, *, epsilon=0.2, learning_rate=0.1):
        """Initialize with the action space of the environment.

        Args:
            action_space: Discrete space exposing ``n`` and ``sample()``.
            epsilon: Exploration probability (previously hard-coded at 0.2).
            learning_rate: Step size for the value update (previously
                hard-coded at 0.1).
        """
        self.action_space = action_space
        self.epsilon = epsilon
        self.learning_rate = learning_rate
        self.last_rewards = []
        # Start all value estimates at 0.5 so no action is preferred initially.
        self.action_values = np.ones(action_space.n) * 0.5

    def select_action(self, observation):
        """
        Select an action based on the current observation.
        Uses a simple epsilon-greedy strategy.

        Note: the observation is currently unused — values are per-action only.
        """
        if random.random() < self.epsilon:
            # Explore: random action
            return self.action_space.sample()
        # Exploit: best action under current value estimates
        return np.argmax(self.action_values)

    def update(self, action, reward):
        """Update the chosen action's value with an exponential moving average."""
        lr = self.learning_rate
        self.action_values[action] = (1 - lr) * self.action_values[action] + lr * reward
        self.last_rewards.append(reward)
|
||||
def run_episode(env, agent, max_steps=10):
    """Play one episode: alternate agent action and env step, return summed reward."""
    obs, _ = env.reset()
    total = 0

    for _ in range(max_steps):
        # Pick and apply an action
        chosen = agent.select_action(obs)
        obs, reward, finished, _, _ = env.step(chosen)

        # Show the exchange, then let the agent learn from the reward
        env.render()
        agent.update(chosen, reward)

        total += reward
        if finished:
            break

    return total
||||
|
||||
def main():
    """Main function to run the tutoring environment."""
    # Pull NOUS_API_KEY (and friends) from .env into the process environment.
    load_dotenv()

    # Bail out early if no API key is available for the teacher's LLM calls.
    if not os.getenv("NOUS_API_KEY"):
        print("Warning: No NOUS_API_KEY found in environment variables.")
        print("Please set this key in your .env file.")
        return

    # NOTE(review): the teacher_agent example uses "config/example_profile.json";
    # confirm which relative path is correct for this runner's working directory.
    profile_path = "example_profile.json"

    # Build the environment and a throwaway epsilon-greedy policy
    env = TutorEnv(profile_path=profile_path, render_mode="human")
    agent = SimpleAgent(env.action_space)

    num_episodes = 5
    episode_rewards = []

    print("\n=== Starting Training ===")
    for episode in range(num_episodes):
        print(f"\n=== Episode {episode + 1}/{num_episodes} ===")
        episode_reward = run_episode(env, agent)
        episode_rewards.append(episode_reward)
        print(f"\nEpisode {episode + 1} completed with total reward: {episode_reward:.2f}")

    print("\n=== Training Summary ===")
    print(f"Average episode reward: {np.mean(episode_rewards):.2f}")
    print(f"Action values learned: {agent.action_values}")

    env.close()

    print("\nDone!")


if __name__ == "__main__":
    main()
Loading…
Add table
Add a link
Reference in a new issue