Skip to content

Commit 40a2d26

Browse files
* Set up fixtures and data for tests
  Signed-off-by: Thara Palanivel <[email protected]>
* Add basic unit tests
  Signed-off-by: Thara Palanivel <[email protected]>
* Setting upper bound for transformers
  Signed-off-by: Thara Palanivel <[email protected]>
* Ignore aim log files
  Signed-off-by: Thara Palanivel <[email protected]>
* Include int num_train_epochs
  Signed-off-by: Thara Palanivel <[email protected]>
* Fix formatting
  Signed-off-by: Thara Palanivel <[email protected]>
* Add copyright notice
  Signed-off-by: Thara Palanivel <[email protected]>
* Address review comments
  Signed-off-by: Thara Palanivel <[email protected]>
* Run inference on tuned model
  Signed-off-by: Thara Palanivel <[email protected]>
* Trainer downloads model
  Signed-off-by: Thara Palanivel <[email protected]>
* add more unit tests and refactor
  Signed-off-by: Anh-Uong <[email protected]>
* Fix formatting
  Signed-off-by: Thara Palanivel <[email protected]>
* Add FT unit test and refactor
  Signed-off-by: Thara Palanivel <[email protected]>
* Removing transformers upper bound cap
  Signed-off-by: Thara Palanivel <[email protected]>
* Address review comments
  Signed-off-by: Thara Palanivel <[email protected]>

---------

Signed-off-by: Thara Palanivel <[email protected]>
Signed-off-by: Anh-Uong <[email protected]>
Co-authored-by: Anh-Uong <[email protected]>
1 parent d8536a9 commit 40a2d26

File tree

7 files changed

+421
-1
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,6 @@ venv/
2525

2626
# Tox envs
2727
.tox
28+
29+
# Aim
30+
.aim

tests/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright The IBM Tuning Team
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

tests/data/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright The IBM Tuning Team
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Helpful datasets for configuring individual unit tests.
16+
"""
17+
# Standard
18+
import os
19+
20+
### Constants used for data
# Directory that holds this module, and therefore the bundled test data files.
# A single-argument os.path.join() returns its argument unchanged, so the
# directory name is used directly.
DATA_DIR = os.path.dirname(__file__)
# Path to the small Twitter complaints dataset stored next to this module.
TWITTER_COMPLAINTS_DATA = os.path.join(DATA_DIR, "twitter_complaints_small.json")

tests/data/twitter_complaints_small.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{"Tweet text":"@HMRCcustomers No this is my first job","ID":0,"Label":2,"text_label":"no complaint","output":"### Text: @HMRCcustomers No this is my first job\n\n### Label: no complaint"}
2+
{"Tweet text":"@KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.","ID":1,"Label":2,"text_label":"no complaint","output":"### Text: @KristaMariePark Thank you for your interest! If you decide to cancel, you can call Customer Care at 1-800-NYTIMES.\n\n### Label: no complaint"}
3+
{"Tweet text":"If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService","ID":2,"Label":1,"text_label":"complaint","output":"### Text: If I can't get my 3rd pair of @beatsbydre powerbeats to work today I'm doneski man. This is a slap in my balls. Your next @Bose @BoseService\n\n### Label: complaint"}
4+
{"Tweet text":"@EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.","ID":3,"Label":1,"text_label":"complaint","output":"### Text: @EE On Rosneath Arial having good upload and download speeds but terrible latency 200ms. Why is this.\n\n### Label: complaint"}
5+
{"Tweet text":"Couples wallpaper, so cute. :) #BrothersAtHome","ID":4,"Label":2,"text_label":"no complaint","output":"### Text: Couples wallpaper, so cute. :) #BrothersAtHome\n\n### Label: no complaint"}
6+
{"Tweet text":"@mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG","ID":5,"Label":2,"text_label":"no complaint","output":"### Text: @mckelldogs This might just be me, but-- eyedrops? Artificial tears are so useful when you're sleep-deprived and sp\u2026 https:\/\/t.co\/WRtNsokblG\n\n### Label: no complaint"}
7+
{"Tweet text":"@Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?","ID":6,"Label":2,"text_label":"no complaint","output":"### Text: @Yelp can we get the exact calculations for a business rating (for example if its 4 stars but actually 4.2) or do we use a 3rd party site?\n\n### Label: no complaint"}
8+
{"Tweet text":"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?","ID":7,"Label":1,"text_label":"complaint","output":"### Text: @nationalgridus I have no water and the bill is current and paid. Can you do something about this?\n\n### Label: complaint"}
9+
{"Tweet text":"Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora","ID":8,"Label":1,"text_label":"complaint","output":"### Text: Never shopping at @MACcosmetics again. Every time I go in there, their employees are super rude\/condescending. I'll take my $$ to @Sephora\n\n### Label: complaint"}
10+
{"Tweet text":"@JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd","ID":9,"Label":2,"text_label":"no complaint","output":"### Text: @JenniferTilly Merry Christmas to as well. You get more stunning every year \ufffd\ufffd\n\n### Label: no complaint"}

tests/helpers.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright The IBM Tuning Team
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Third Party
16+
import transformers
17+
18+
# Local
19+
from tuning.config import configs, peft_config
20+
21+
22+
def causal_lm_train_kwargs(train_kwargs):
    """Turn a dict of training kwargs into the argument objects for a Causal LM train call.

    Args:
        train_kwargs: dict of keyword arguments handed to
            ``HfArgumentParser.parse_dict`` with ``allow_extra_keys=True``.
            Its ``"peft_method"`` entry (if any) selects which tuning config
            is returned.

    Returns:
        A 4-tuple ``(model_args, data_args, training_args, tune_config)``.
        ``tune_config`` is the parsed ``LoraConfig`` when
        ``train_kwargs.get("peft_method") == "lora"``; for any other value,
        including a missing key, it is the parsed ``PromptTuningConfig``.
    """
    arg_parser = transformers.HfArgumentParser(
        dataclass_types=(
            configs.ModelArguments,
            configs.DataArguments,
            configs.TrainingArguments,
            peft_config.LoraConfig,
            peft_config.PromptTuningConfig,
        )
    )
    # parse_dict yields one populated instance per dataclass, in the order the
    # dataclass types were registered above.
    parsed = arg_parser.parse_dict(train_kwargs, allow_extra_keys=True)
    model_args, data_args, training_args, lora_cfg, prompt_tuning_cfg = parsed

    # Only "lora" picks the LoRA config; everything else falls back to the
    # prompt tuning config.
    if train_kwargs.get("peft_method") == "lora":
        tune_config = lora_cfg
    else:
        tune_config = prompt_tuning_cfg

    return model_args, data_args, training_args, tune_config

0 commit comments

Comments
 (0)