Skip to content

Commit 40f3966

Browse files
committed
Add "full" prompt option
1 parent b3b8077 commit 40f3966

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

bigcode_eval/tasks/shadereval.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def __init__(self):
131131
# and if the evaluation requires executing the generated code in `requires_execution`.
132132
stop_words=["\nfloat ", "\nvec", "\nint", "\nvoid", "\nmat"], #new function starts... so all the keywords
133133
requires_execution=True, #we run shadercode - could that be harmful? (all in the metric)
134+
prompt="minimal", # "minimal" or "full". "minimal" is the function header and comments before/after it, "full" is the whole code up until the function declaration ends
134135
)
135136

136137
def get_dataset(self):
@@ -146,13 +147,19 @@ def get_prompt(self, doc):
146147
# TODO: build the prompt for the language model from a sample `doc` from the dataset
147148
"""
148149
Builds the prompt for the LM to generate from.
150+
if prompt == "minimal" -> function header and comments before/after it
151+
if prompt == "full" -> also includes full code before the function header
149152
:param doc: dict[str: str]
150153
sample from the test dataset
151154
:return: str
152155
"""
153-
154-
# alternatively, give the whole code up untill the function declaration ends? as in this paper: https://arxiv.org/abs/2306.03203
155-
return doc["model_ctx"]
156+
model_context = ""
157+
if self.prompt == "full":
158+
# alternatively, give the whole code up until the function declaration ends? as in this paper: https://arxiv.org/abs/2306.03203
159+
model_context += doc["full_code"].encode("utf-8")[:doc["func_range"][0]].decode("utf-8") #returns full original code up until the function declaration ends
160+
# only have one alternative, but could be more?
161+
model_context += doc["model_ctx"]
162+
return model_context
156163

157164
def get_reference(self, doc):
158165
# TODO: get the reference solution from a sample `doc` from the dataset
@@ -172,7 +179,7 @@ def remove_last_block(self, code):
172179
if w in code:
173180
code = code[:code.find(w)]
174181

175-
### Find the first occassion where a chain of { } is closed??
182+
### Find the first occasion where a chain of { } is closed??
176183
open_brackets = 1
177184
cut = False
178185
for i, c in enumerate(code.encode("utf-8")):
@@ -210,9 +217,13 @@ def postprocess_generation(self, generation, idx):
210217
model_ctx = ref["model_ctx"]
211218
full_code = ref["full_code"]
212219
start, end = ref["func_range"]
213-
gen = self.remove_last_block(generation.encode("utf-8")[len(model_ctx.encode("utf-8")):].decode("utf-8")) #remove last block to avoid syntax errors
214220
before_gen = full_code.encode("utf-8")[:start].decode("utf-8")
215221
after_gen = full_code.encode("utf-8")[end:].decode("utf-8")
222+
223+
if self.prompt == "full":
224+
gen = self.remove_last_block(generation.encode("utf-8")[start + len(model_ctx.encode("utf-8")):].decode("utf-8"))
225+
else:
226+
gen = self.remove_last_block(generation.encode("utf-8")[len(model_ctx.encode("utf-8")):].decode("utf-8")) #remove last block to avoid syntax errors
216227
return before_gen + model_ctx + gen + after_gen #does this patch it together correctly?
217228

218229
def process_results(self, generations, references):

0 commit comments

Comments
 (0)