mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
22884d2bf7
commit
d84e3c70b7
16 changed files with 270 additions and 143 deletions
|
|
@@ -40,12 +40,16 @@ class SEEDBench2Plus(EvalBase):
|
|||
except Exception as e:
|
||||
print(f"Warning: Could not load SEED-Bench2: {e}")
|
||||
try:
|
||||
dataset = load_dataset("lmms-lab/SEED-Bench", split=split, streaming=True)
|
||||
dataset = load_dataset(
|
||||
"lmms-lab/SEED-Bench", split=split, streaming=True
|
||||
)
|
||||
if max_samples:
|
||||
data = list(dataset.take(max_samples))
|
||||
else:
|
||||
data = list(dataset.take(1000))
|
||||
print(f"Loaded {len(data)} examples from SEED-Bench ({split}, streaming)")
|
||||
print(
|
||||
f"Loaded {len(data)} examples from SEED-Bench ({split}, streaming)"
|
||||
)
|
||||
return data
|
||||
except Exception:
|
||||
raise ValueError(f"Could not load SEED-Bench2-Plus dataset: {e}")
|
||||
|
|
@@ -103,15 +107,19 @@ class SEEDBench2Plus(EvalBase):
|
|||
|
||||
content = []
|
||||
if image_base64:
|
||||
content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
|
||||
})
|
||||
content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
|
||||
}
|
||||
)
|
||||
content.append({"type": "text", "text": prompt})
|
||||
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
def extract_answer(self, response: str, num_choices: int) -> Tuple[Optional[str], str]:
|
||||
def extract_answer(
|
||||
self, response: str, num_choices: int
|
||||
) -> Tuple[Optional[str], str]:
|
||||
valid_letters = set(ascii_uppercase[:num_choices])
|
||||
|
||||
letter, method = extract_letter_from_answer_tag(response, valid_letters)
|
||||
|
|
@@ -154,7 +162,8 @@ class SEEDBench2Plus(EvalBase):
|
|||
num_choices = len(choices) if choices else 4
|
||||
if num_choices == 0:
|
||||
num_choices = sum(
|
||||
1 for letter in ascii_uppercase[:6]
|
||||
1
|
||||
for letter in ascii_uppercase[:6]
|
||||
if letter in data_item and data_item[letter] is not None
|
||||
)
|
||||
num_choices = max(num_choices, 4)
|
||||
|
|
@@ -168,7 +177,9 @@ class SEEDBench2Plus(EvalBase):
|
|||
sample = {
|
||||
"id": data_item.get("index", data_item.get("question_id", "")),
|
||||
"question": data_item.get("question", "")[:200],
|
||||
"category": data_item.get("question_type_id", data_item.get("category", "")),
|
||||
"category": data_item.get(
|
||||
"question_type_id", data_item.get("category", "")
|
||||
),
|
||||
"answer": answer,
|
||||
"prediction": extracted,
|
||||
"raw_response": response[:500],
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue