
Commit 72575f3

Merge pull request #132 from nnnyt/refactor
[REFACTOR] Update Structure and Add Pipeline
2 parents 1cfc9ca + 139c1be commit 72575f3

File tree

151 files changed: +7894, -47202 lines


.gitignore

Lines changed: 7 additions & 1 deletion
@@ -110,4 +110,10 @@ venv.bak/
 .pyre/
 
 # User Definition
-data/
+data/
+deprecated/
+tmp*/
+jieba.cache
+*.kv
+*.zip
+examples/test_model

AUTHORS.md

Lines changed: 2 additions & 0 deletions
@@ -18,5 +18,7 @@
 
 [Yuting Ning](https://github.com/nnnyt)
 
+[Jundong Wu](https://github.com/wintermelon008)
+
 
 The starred contributors are the corresponding authors.

CHANGE.txt

Lines changed: 9 additions & 0 deletions
@@ -1,3 +1,12 @@
+v0.0.9
+1. Refactor tokenizers: Basic Tokenizer and Pretrained Tokenizer
+2. Refactor model structures following the huggingface style for Elmo, BERT, DisenQNet and QuesNet
+3. Add PreprocessingPipeline and Pipeline
+4. Add downstream tasks: knowledge prediction and property prediction
+5. Fix a bug in RNN which caused ELMo not to converge
+6. Move all the test models to modelhub
+7. Update test data files
+
 v0.0.8
 1. add Elmo
 2. add DisenQNet

EduNLP/I2V/i2v.py

Lines changed: 84 additions & 95 deletions
Large diffs are not rendered by default.

EduNLP/ModelZoo/__init__.py

Lines changed: 2 additions & 3 deletions
@@ -1,4 +1,3 @@
-# coding: utf-8
-# 2021/7/12 @ tongshiwei
-
 from .utils import *
+from .bert import *
+from .rnn import *

EduNLP/ModelZoo/base_model.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
import torch.nn as nn
import json
import os
from pathlib import Path
import torch
from transformers import PretrainedConfig
# import logging
from ..utils import logger


class BaseModel(nn.Module):
    base_model_prefix = ''

    def __init__(self):
        super(BaseModel, self).__init__()
        self.config = PretrainedConfig()

    def forward(self, *input):
        raise NotImplementedError

    def save_pretrained(self, output_dir):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)
        model_path = os.path.join(output_dir, 'pytorch_model.bin')
        model_path = Path(model_path)
        torch.save(self.state_dict(), model_path.open('wb'))
        self.save_config(output_dir)

    @classmethod
    def from_pretrained(cls, pretrained_model_path, *args, **kwargs):
        config_path = os.path.join(pretrained_model_path, "config.json")
        model_path = os.path.join(pretrained_model_path, "pytorch_model.bin")
        model = cls.from_config(config_path, *args, **kwargs)
        loaded_state_dict = torch.load(model_path)
        loaded_keys = loaded_state_dict.keys()
        expected_keys = model.state_dict().keys()

        prefix = cls.base_model_prefix

        if set(loaded_keys) == set(expected_keys):
            # same architecture
            model.load_state_dict(loaded_state_dict)
        else:
            has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
            expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)

            new_loaded_state_dict = {}
            if expects_prefix_module and not has_prefix_module:
                # add prefix
                for key in loaded_keys:
                    new_loaded_state_dict['.'.join([prefix, key])] = loaded_state_dict[key]
            if has_prefix_module and not expects_prefix_module:
                # remove prefix
                for key in loaded_keys:
                    if key.startswith(prefix):
                        new_loaded_state_dict['.'.join(key.split('.')[1:])] = loaded_state_dict[key]
            if has_prefix_module and expects_prefix_module:
                # both have prefix, only load the base encoder
                for key in loaded_keys:
                    if key.startswith(prefix):
                        new_loaded_state_dict[key] = loaded_state_dict[key]
            loaded_state_dict = new_loaded_state_dict
            model.load_state_dict(loaded_state_dict, strict=False)
            loaded_keys = loaded_state_dict.keys()
        missing_keys = set(expected_keys) - set(loaded_keys)
        if len(missing_keys) == 0:
            logger.info(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
                f" {pretrained_model_path}.\nIf your task is similar to the task the model of the checkpoint"
                f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
                " training."
            )
        elif len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
                f" {pretrained_model_path} and are newly initialized: {missing_keys}\nYou should probably"
                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        return model

    def save_config(self, config_dir):
        config_path = os.path.join(config_dir, "config.json")
        with open(config_path, "w", encoding="utf-8") as wf:
            json.dump(self.config.to_dict(), wf, ensure_ascii=False, indent=2)

    @classmethod
    def from_config(cls, config_path, *args, **kwargs):
        raise NotImplementedError
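
For orientation, here is a minimal sketch of how a subclass is meant to use BaseModel's save/load contract; ToyModel, its fields, and the "tmp/toy" path are hypothetical and not part of this commit.

# Hypothetical subclass illustrating the save_pretrained / from_pretrained round trip
# provided by BaseModel (ToyModel is illustrative only, not part of this commit).
import json
import torch.nn as nn
from EduNLP.ModelZoo.base_model import BaseModel


class ToyModel(BaseModel):
    base_model_prefix = 'encoder'

    def __init__(self, hidden_size=16):
        super(ToyModel, self).__init__()
        self.encoder = nn.Linear(hidden_size, hidden_size)
        self.config.hidden_size = hidden_size  # picked up by save_config via config.to_dict()

    def forward(self, x):
        return self.encoder(x)

    @classmethod
    def from_config(cls, config_path, **kwargs):
        # each subclass decides how to rebuild itself from config.json
        with open(config_path, "r", encoding="utf-8") as rf:
            config = json.load(rf)
        return cls(hidden_size=config.get("hidden_size", 16))


model = ToyModel(hidden_size=32)
model.save_pretrained("tmp/toy")                # writes pytorch_model.bin and config.json
restored = ToyModel.from_pretrained("tmp/toy")  # rebuilds from config, then loads weights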

EduNLP/ModelZoo/bert/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .bert import *

EduNLP/ModelZoo/bert/bert.py

Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
import torch
from torch import nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from baize.torch import load_net
import torch.nn.functional as F
import json
import os
from ..base_model import BaseModel
from transformers.modeling_outputs import ModelOutput
from transformers import BertModel, PretrainedConfig
from typing import List, Optional
from ..rnn.harnn import HAM

__all__ = ["BertForPropertyPrediction", "BertForKnowledgePrediction"]


class BertForPPOutput(ModelOutput):
    loss: torch.FloatTensor = None
    logits: torch.FloatTensor = None


class BertForPropertyPrediction(BaseModel):
    def __init__(self, pretrained_model_dir=None, head_dropout=0.5):
        super(BertForPropertyPrediction, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_dir)
        self.hidden_size = self.bert.config.hidden_size
        self.head_dropout = head_dropout
        self.dropout = nn.Dropout(head_dropout)
        self.classifier = nn.Linear(self.hidden_size, 1)
        self.sigmoid = nn.Sigmoid()
        self.criterion = nn.MSELoss()

        self.config = {k: v for k, v in locals().items() if k not in ["self", "__class__"]}
        self.config['architecture'] = 'BertForPropertyPrediction'
        self.config = PretrainedConfig.from_dict(self.config)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                labels=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        item_embeds = outputs.last_hidden_state[:, 0, :]
        item_embeds = self.dropout(item_embeds)

        logits = self.sigmoid(self.classifier(item_embeds)).squeeze(1)
        loss = None
        if labels is not None:
            loss = self.criterion(logits, labels)
        return BertForPPOutput(
            loss=loss,
            logits=logits,
        )

    @classmethod
    def from_config(cls, config_path, **kwargs):
        with open(config_path, "r", encoding="utf-8") as rf:
            model_config = json.load(rf)
        model_config.update(kwargs)
        return cls(
            pretrained_model_dir=model_config['pretrained_model_dir'],
            head_dropout=model_config.get("head_dropout", 0.5)
        )

    # @classmethod
    # def from_pretrained(cls):
    #     NotImplementedError
    #     # Need to verify compatibility with huggingface models


class BertForKnowledgePrediction(BaseModel):
    def __init__(self,
                 num_classes_list: List[int] = None,
                 num_total_classes: int = None,
                 pretrained_model_dir=None,
                 head_dropout=0.5,
                 flat_cls_weight=0.5,
                 attention_unit_size=256,
                 fc_hidden_size=512,
                 beta=0.5,
                 ):
        super(BertForKnowledgePrediction, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_dir)
        self.hidden_size = self.bert.config.hidden_size
        self.head_dropout = head_dropout
        self.dropout = nn.Dropout(head_dropout)
        self.classifier = nn.Linear(self.hidden_size, 1)
        self.sigmoid = nn.Sigmoid()
        self.criterion = nn.MSELoss()
        self.flat_classifier = nn.Linear(self.hidden_size, num_total_classes)
        self.ham_classifier = HAM(
            num_classes_list=num_classes_list,
            num_total_classes=num_total_classes,
            sequence_model_hidden_size=self.bert.config.hidden_size,
            attention_unit_size=attention_unit_size,
            fc_hidden_size=fc_hidden_size,
            beta=beta,
            dropout_rate=head_dropout
        )
        self.flat_cls_weight = flat_cls_weight
        self.num_classes_list = num_classes_list
        self.num_total_classes = num_total_classes

        self.config = {k: v for k, v in locals().items() if k not in ["self", "__class__"]}
        self.config['architecture'] = 'BertForKnowledgePrediction'
        self.config = PretrainedConfig.from_dict(self.config)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                labels=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        item_embeds = outputs.last_hidden_state[:, 0, :]
        item_embeds = self.dropout(item_embeds)
        tokens_embeds = outputs.last_hidden_state
        tokens_embeds = self.dropout(tokens_embeds)
        flat_logits = self.sigmoid(self.flat_classifier(item_embeds))
        ham_outputs = self.ham_classifier(tokens_embeds)
        ham_logits = self.sigmoid(ham_outputs.scores)
        logits = self.flat_cls_weight * flat_logits + (1 - self.flat_cls_weight) * ham_logits
        loss = None
        if labels is not None:
            labels = torch.sum(torch.nn.functional.one_hot(labels, num_classes=self.num_total_classes), dim=1)
            labels = labels.float()
            loss = self.criterion(logits, labels)
        return BertForPPOutput(
            loss=loss,
            logits=logits,
        )

    @classmethod
    def from_config(cls, config_path, **kwargs):
        with open(config_path, "r", encoding="utf-8") as rf:
            model_config = json.load(rf)
        model_config.update(kwargs)
        return cls(
            pretrained_model_dir=model_config['pretrained_model_dir'],
            head_dropout=model_config.get("head_dropout", 0.5),
            num_classes_list=model_config.get('num_classes_list'),
            num_total_classes=model_config.get('num_total_classes'),
            flat_cls_weight=model_config.get('flat_cls_weight', 0.5),
            attention_unit_size=model_config.get('attention_unit_size', 256),
            fc_hidden_size=model_config.get('fc_hidden_size', 512),
            beta=model_config.get('beta', 0.5),
        )

    # @classmethod
    # def from_pretrained(cls):
    #     NotImplementedError
    #     # Need to verify compatibility with huggingface models
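
For a rough sense of the downstream API, a hedged usage sketch for BertForPropertyPrediction follows; the "path/to/bert" checkpoint directory, the example items, and the labels are assumptions, not part of this commit.

# Hypothetical usage of BertForPropertyPrediction (sketch only, not part of this commit).
# "path/to/bert" stands for any local directory holding a huggingface BERT checkpoint.
import torch
from transformers import BertTokenizer
from EduNLP.ModelZoo.bert import BertForPropertyPrediction

tokenizer = BertTokenizer.from_pretrained("path/to/bert")
model = BertForPropertyPrediction(pretrained_model_dir="path/to/bert", head_dropout=0.5)

batch = tokenizer(["Solve x + 1 = 2.", "Compute 3 * 4."],
                  padding=True, truncation=True, return_tensors="pt")
labels = torch.tensor([0.3, 0.7])  # regression targets in [0, 1], e.g. difficulty

output = model(input_ids=batch["input_ids"],
               attention_mask=batch["attention_mask"],
               token_type_ids=batch["token_type_ids"],
               labels=labels)
print(output.loss)          # MSE loss between sigmoid scores and labels
print(output.logits.shape)  # (batch_size,) predicted property scores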
Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 # -*- coding: utf-8 -*-
 
-from .disenqnet import DisenQNet
+from .disenqnet import DisenQNet, DisenQNetForPreTraining
