mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-30 17:40:36 +00:00
add eval runner
This commit is contained in:
parent
405efa8302
commit
8ec5066998
3 changed files with 410 additions and 8 deletions
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator, List, Union
|
||||
|
||||
|
|
@ -352,12 +353,18 @@ class ServerManager:
|
|||
most_available_server_num_slots = server.sem._value
|
||||
|
||||
# Create ManagedServer wrapping the selected server
|
||||
managed = ManagedServer(
|
||||
server=self.servers[most_available_server], tokenizer=tokenizer
|
||||
)
|
||||
if isinstance(self.servers[most_available_server], OpenAIServer):
|
||||
warnings.warn(
|
||||
"Using OpenAIServer with managed_server does not allow for state tracking"
|
||||
)
|
||||
yield self.servers[most_available_server]
|
||||
else:
|
||||
managed = ManagedServer(
|
||||
server=self.servers[most_available_server], tokenizer=tokenizer
|
||||
)
|
||||
|
||||
try:
|
||||
yield managed
|
||||
finally:
|
||||
# Clean up: reset tracked sequences
|
||||
managed.reset()
|
||||
try:
|
||||
yield managed
|
||||
finally:
|
||||
# Clean up: reset tracked sequences
|
||||
managed.reset()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue