mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
fixed linting in latest main
This commit is contained in:
parent
00dd120067
commit
c72a27d376
2 changed files with 42 additions and 39 deletions
|
|
@ -1,19 +1,14 @@
|
|||
import random
|
||||
from typing import Dict, List, Optional, Tuple, TypedDict, Union
|
||||
|
||||
from datasets import load_dataset
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
import regex as re
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import docker
|
||||
import httpx
|
||||
import docker, os
|
||||
|
||||
from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataGroup
|
||||
from atroposlib.type_definitions import Item, number
|
||||
import regex as re
|
||||
from datasets import load_dataset
|
||||
|
||||
from atroposlib.envs.base import BaseEnv, ScoredDataGroup
|
||||
from atroposlib.type_definitions import GameHistory, Item
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
|
|
@ -25,17 +20,15 @@ system_prompt = (
|
|||
"tags, and then provide your solution or response to the problem.\n\n"
|
||||
)
|
||||
|
||||
|
||||
async def submit_code(client, code, test_input, language="python"):
|
||||
url = "http://localhost:5002/execute"
|
||||
payload = {
|
||||
"code": code,
|
||||
"input": test_input,
|
||||
"language": language
|
||||
}
|
||||
payload = {"code": code, "input": test_input, "language": language}
|
||||
response = await client.post(url, json=payload)
|
||||
response_json = response.json()
|
||||
return response_json["output"]
|
||||
|
||||
|
||||
async def get_results(code, answer):
|
||||
async with httpx.AsyncClient() as client:
|
||||
tasks = []
|
||||
|
|
@ -45,18 +38,22 @@ async def get_results(code, answer):
|
|||
results = await asyncio.gather(*tasks)
|
||||
return [result for result in results]
|
||||
|
||||
|
||||
def init_docker():
|
||||
client = docker.from_env()
|
||||
|
||||
def build_docker_image():
|
||||
try:
|
||||
# Build the Docker image
|
||||
print("Building Docker image...")
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__)) # Get the current directory of the script
|
||||
current_dir = os.path.dirname(
|
||||
os.path.abspath(__file__)
|
||||
) # Get the current directory of the script
|
||||
image, logs = client.images.build(path=current_dir, tag="code-executor")
|
||||
|
||||
# Print the build logs
|
||||
for log in logs:
|
||||
print(log.get('stream', '').strip())
|
||||
print(log.get("stream", "").strip())
|
||||
|
||||
print("Docker image built successfully.")
|
||||
return image
|
||||
|
|
@ -67,19 +64,20 @@ def init_docker():
|
|||
try:
|
||||
# Run the Docker container
|
||||
print("Running Docker container...")
|
||||
container = client.containers.run("code-executor",
|
||||
ports={'5002/tcp': 5002},
|
||||
detach=True) # Runs in detached mode (in the background)
|
||||
container = client.containers.run(
|
||||
"code-executor", ports={"5002/tcp": 5002}, detach=True
|
||||
) # Runs in detached mode (in the background)
|
||||
|
||||
print(f"Docker container is running with ID: {container.id}")
|
||||
return container
|
||||
except docker.errors.ContainerError as e:
|
||||
print(f"Error during Docker container run: {e}")
|
||||
|
||||
|
||||
build_docker_image()
|
||||
container = run_docker_container()
|
||||
return container
|
||||
|
||||
|
||||
class CodingEnv(BaseEnv):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
|
@ -111,7 +109,7 @@ class CodingEnv(BaseEnv):
|
|||
item[1],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
to_postprocess = await self.score(to_score)
|
||||
return to_postprocess, to_backlog
|
||||
|
||||
|
|
@ -144,18 +142,23 @@ class CodingEnv(BaseEnv):
|
|||
prompt = tuple(
|
||||
[frozenset({"role": "user", "content": next_item["description"]}.items())]
|
||||
)
|
||||
answer = (tuple(next_item["private_tests"]["input"]), tuple(next_item["private_tests"]["output"]), tuple(next_item["generated_tests"]["input"]), tuple(next_item["generated_tests"]["output"]))
|
||||
answer = (
|
||||
tuple(next_item["private_tests"]["input"]),
|
||||
tuple(next_item["private_tests"]["output"]),
|
||||
tuple(next_item["generated_tests"]["input"]),
|
||||
tuple(next_item["generated_tests"]["output"]),
|
||||
)
|
||||
return (prompt, answer)
|
||||
|
||||
def extract_python_code_blocks(self, text):
|
||||
# Regex specifically looks for ```python\n...code...\n```
|
||||
pattern = r'^```(?:\w+)?\s*\n(.*?)(?=^```)```'
|
||||
# Regex specifically looks for ```python\n...code...\n```
|
||||
pattern = r"^```(?:\w+)?\s*\n(.*?)(?=^```)```"
|
||||
result = re.findall(pattern, text, re.DOTALL | re.MULTILINE)
|
||||
python_blocks = [r for r in result]
|
||||
return python_blocks
|
||||
|
||||
async def score(self, rollout_group_data) -> Optional[ScoredDataGroup]:
|
||||
#print("Rollout group data", rollout_group_data)
|
||||
# print("Rollout group data", rollout_group_data)
|
||||
scores = ScoredDataGroup()
|
||||
scores["tokens"] = list()
|
||||
scores["masks"] = list()
|
||||
|
|
@ -192,5 +195,6 @@ class CodingEnv(BaseEnv):
|
|||
# return None # If all the same, we return None
|
||||
return scores
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
CodingEnv.cli()
|
||||
|
|
|
|||
|
|
@ -14,15 +14,17 @@ curl -X POST http://localhost:5002/execute \
|
|||
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import uuid
|
||||
from flask import Flask, request, jsonify
|
||||
import time
|
||||
|
||||
from flask import Flask, jsonify, request
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route('/execute', methods=['POST'])
|
||||
|
||||
@app.route("/execute", methods=["POST"])
|
||||
def execute_code():
|
||||
try:
|
||||
# Receive C++ code from API request
|
||||
|
|
@ -46,16 +48,13 @@ def execute_code():
|
|||
input=test_cases,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5 # Prevent infinite loops
|
||||
timeout=5, # Prevent infinite loops
|
||||
)
|
||||
|
||||
# Cleanup temporary files
|
||||
os.remove(py_filename)
|
||||
|
||||
return jsonify({
|
||||
"output": exec_result.stdout,
|
||||
"error": exec_result.stderr
|
||||
})
|
||||
return jsonify({"output": exec_result.stdout, "error": exec_result.stderr})
|
||||
""" C++ stuff
|
||||
|
||||
file_id = str(uuid.uuid4())
|
||||
|
|
@ -90,6 +89,6 @@ def execute_code():
|
|||
except Exception as e:
|
||||
return jsonify({"error": str(e)})
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(host="0.0.0.0", port=5002)
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(host="0.0.0.0", port=5002)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue