[Eval] Use exact matching for Y/N and multi-choice when OPENAI_API_KEY not set (open-compass#44)

* update

* handle GPT_API_KEY missing
kennymckormick committed Jan 8, 2024
1 parent bc4b28b commit a8f2b84
Showing 3 changed files with 30 additions and 30 deletions.
5 changes: 4 additions & 1 deletion vlmeval/evaluate/multiple_choice.py
@@ -241,8 +241,11 @@ def multiple_choice_eval(eval_file, dataset=None, model='chatgpt-0613', nproc=4,
 model_name = 'gpt-3.5-turbo-0613'
 if INTERNAL:
     model = OpenAIWrapperInternal(model_name, verbose=verbose, retry=10)
-else:
+elif gpt_key_set():
     model = OpenAIWrapper(model_name, verbose=verbose, retry=10)
+else:
+    logger.error('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
+    model = None
 
 logger.info(f'Evaluating {eval_file}')
 result_file = eval_file.replace(f'.{suffix}', f'_{name_str}_result.pkl')
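For context, a minimal sketch of what an exact-matching fallback can look like for multiple-choice answers when model is None. It is illustrative only; the option dict and the matching rules below are assumptions, not the repository's own matcher.

import string

def exact_match_option(prediction, options):
    # options: e.g. {'A': 'cat', 'B': 'dog'}; returns a letter or None
    pred = prediction.strip()
    # Case 1: the reply is essentially a bare option letter, e.g. "B" or "(B)"
    cleaned = pred.strip(string.punctuation + ' ').upper()
    if cleaned in options:
        return cleaned
    # Case 2: exactly one option's text appears verbatim in the reply
    hits = [k for k, v in options.items() if v.lower() in pred.lower()]
    if len(hits) == 1:
        return hits[0]
    return None  # ambiguous or no match; scored as incorrect / unanswered

print(exact_match_option('(B)', {'A': 'cat', 'B': 'dog'}))                   # B
print(exact_match_option('I think it is a dog.', {'A': 'cat', 'B': 'dog'}))  # B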
25 changes: 14 additions & 11 deletions vlmeval/evaluate/yes_or_no.py
@@ -175,18 +175,21 @@ def YOrN_eval(eval_file, model='chatgpt-0613', nproc=4, verbose=False, dataset=N
 
 if INTERNAL:
     model = OpenAIWrapperInternal(model_name, verbose=verbose, retry=10)
-else:
+elif gpt_key_set():
     model = OpenAIWrapper(model_name, verbose=verbose, retry=10)
-
-lt = len(unknown)
-lines = [unknown.iloc[i] for i in range(lt)]
-tups = [(model, line) for line in lines]
-indices = list(unknown['index'])
-
-if len(tups):
-    res = track_progress_rich(YOrN_auxeval, tups, nproc=nproc, chunksize=nproc, keys=indices, save=tmp_file)
-    for k, v in zip(indices, res):
-        ans_map[k] = v
+else:
+    logger.error('OPENAI_API_KEY is not set properly, will use exact matching for evaluation')
+    model = None
+
+if model is not None:
+    lt = len(unknown)
+    lines = [unknown.iloc[i] for i in range(lt)]
+    tups = [(model, line) for line in lines]
+    indices = list(unknown['index'])
+    if len(tups):
+        res = track_progress_rich(YOrN_auxeval, tups, nproc=nproc, chunksize=nproc, keys=indices, save=tmp_file)
+        for k, v in zip(indices, res):
+            ans_map[k] = v
 
 data['extracted'] = [ans_map[x] for x in data['index']]
 dump(data, storage)
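Below is an illustrative sketch of the kind of first-pass Yes/No extraction that exact matching amounts to; rows it cannot resolve are the 'unknown' ones that YOrN_auxeval would otherwise send to GPT. The rules are an assumption for illustration, not the repository's implementation.

def yorn_exact_match(prediction):
    # Strip trivial punctuation and look for an unambiguous yes/no token
    words = prediction.strip().lower().replace('.', ' ').replace(',', ' ').split()
    if 'yes' in words and 'no' not in words:
        return 'Yes'
    if 'no' in words and 'yes' not in words:
        return 'No'
    return 'Unknown'  # unresolved; would need the (now optional) GPT-based step

print(yorn_exact_match('Yes, the object is present.'))  # Yes
print(yorn_exact_match('The answer is no'))             # No
print(yorn_exact_match('Hard to tell'))                 # Unknown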
30 changes: 12 additions & 18 deletions vlmeval/smp.py
@@ -1,4 +1,4 @@
-# flake8: noqa: F401, F403
+# flake8: noqa: F401, F403
 import abc
 import argparse
 import csv
@@ -29,6 +29,14 @@
 from huggingface_hub import scan_cache_dir
 import logging
 
+def gpt_key_set():
+    openai_key = os.environ.get('OPENAI_API_KEY', None)
+    return isinstance(openai_key, str) and openai_key.startswith('sk-')
+
+def apiok(wrapper):
+    s = wrapper.generate("Hello!")
+    return wrapper.fail_msg not in s
+
 def isimg(s):
     return osp.exists(s) or s.startswith('http')
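A hypothetical caller-side sketch of how the two new helpers fit together; only gpt_key_set and apiok come from this diff, while the vlmeval.api import path and the judge-building code are assumptions.

from vlmeval.smp import gpt_key_set, apiok    # helpers added in this commit

judge = None
if gpt_key_set():                              # OPENAI_API_KEY present and shaped like 'sk-...'
    from vlmeval.api import OpenAIWrapper      # assumption: import path may differ by version
    judge = OpenAIWrapper('gpt-3.5-turbo-0613', verbose=False, retry=10)
    if not apiok(judge):                       # one cheap "Hello!" round trip to verify the key works
        judge = None
# judge is None -> the caller falls back to exact matching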

@@ -219,24 +227,17 @@ def last_modified(pth):
 def mmqa_display(question):
     question = {k.lower(): v for k, v in question.items()}
     keys = list(question.keys())
-    if 'index' in keys:
-        keys.remove('index')
-    keys.remove('image')
+    keys = [k for k in keys if k not in ['index', 'image']]
 
     images = question['image']
     if isinstance(images, str):
         images = [images]
 
-    idx = 'XXX'
-    if 'index' in question:
-        idx = question.pop('index')
+    idx = question.pop('index', 'XXX')
     print(f'INDEX: {idx}')
 
     for im in images:
-        image = decode_base64_to_image(im)
-        w, h = image.size
-        ratio = 500 / h
-        image = image.resize((int(ratio * w), int(ratio * h)))
+        image = decode_base64_to_image(im, target_size=512)
         display(image)
 
     for k in keys:
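For reference, the two dict idioms the rewrite leans on, shown standalone with hypothetical sample data.

question = {'index': 3, 'image': '<base64 string>', 'question': 'What is shown?'}
keys = [k for k in question if k not in ['index', 'image']]  # keep only displayable fields
idx = question.pop('index', 'XXX')  # falls back to 'XXX' when 'index' is missing
print(idx, keys)                    # 3 ['question']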
@@ -289,13 +290,6 @@ def mwlines(lines, fname):
     with open(fname, 'w') as fout:
         fout.write('\n'.join(lines))
 
-def default_set(self, args, name, default):
-    if hasattr(args, name):
-        val = getattr(args, name)
-        setattr(self, name, val)
-    else:
-        setattr(self, name, default)
-
 def dict_merge(dct, merge_dct):
     for k, _ in merge_dct.items():
         if (k in dct and isinstance(dct[k], dict) and isinstance(merge_dct[k], dict)): #noqa
