Model
class Manteia.Model.EarlyStopping(patience=2, delta=0, path=None, verbose=True)
Stops the training early if the validation loss does not improve within a given patience.
save_checkpoint(acc_validation, model, device_model)
Saves the model when the validation loss decreases.
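The patience/delta mechanism described above can be sketched in plain Python. This is a hypothetical, framework-free illustration (the class name EarlyStoppingSketch and the saved_epochs list are assumptions, not Manteia's code): a new result counts as an improvement only if it beats the best loss by more than delta, and training stops once patience consecutive checks pass without one.

```python
class EarlyStoppingSketch:
    """Hypothetical sketch of patience-based early stopping (not Manteia's code)."""

    def __init__(self, patience=2, delta=0.0, verbose=True):
        self.patience = patience  # checks to wait after the last improvement
        self.delta = delta        # minimum decrease that counts as an improvement
        self.verbose = verbose
        self.best_loss = None
        self.counter = 0
        self.early_stop = False
        self.saved_epochs = []    # stands in for checkpoints written to disk

    def __call__(self, val_loss, epoch):
        if self.best_loss is None or val_loss < self.best_loss - self.delta:
            # Improvement: remember it, reset the counter, save a checkpoint.
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(val_loss, epoch)
        else:
            # No improvement: count it and stop once patience is exhausted.
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, epoch):
        # A real implementation would save the model weights here.
        if self.verbose:
            print(f"Validation loss improved to {val_loss:.4f}, saving checkpoint.")
        self.saved_epochs.append(epoch)


stopper = EarlyStoppingSketch(patience=2, verbose=False)
for epoch, loss in enumerate([0.90, 0.70, 0.72, 0.71, 0.65]):
    stopper(loss, epoch)
    if stopper.early_stop:
        break
```

With patience=2, the two epochs after the best loss (0.70) trigger the stop at epoch 3, so only the checkpoints from epochs 0 and 1 are kept and the later 0.65 is never seen.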
class Manteia.Model.Model(model_name='bert', model_type=None, task='classification', num_labels=0, epochs=None, MAX_SEQ_LEN=128, early_stopping=False, path='./model/', verbose=True)
Class used to build a model.
Args:
- model_name (string, optional, defaults to 'bert'): name of the model.
- num_labels (int, optional, defaults to 0): number of categories for classification.
Example:

    from sklearn.model_selection import train_test_split

    from Manteia.Model import Model, encode_text, encode_label, Create_DataLoader_train
    from Manteia.Preprocess import Preprocess

    model_name = 'bert'  # name of the model to load
    MAX_SEQ_LEN = 128    # maximum sequence length (the Model default)

    documents = ['a text', 'text b']
    labels = ['a', 'b']

    # Preprocess the documents and collect the label list.
    pp = Preprocess(documents=documents, labels=labels)

    # Build the model and load it (this provides model.tokenizer).
    model = Model(model_name=model_name, num_labels=len(pp.list_labels))
    model.load()

    # Hold out 10% of the data for validation.
    train_text, validation_text, train_labels, validation_labels = train_test_split(
        pp.documents, pp.labels, random_state=2018, test_size=0.1)

    # Tokenize the texts into input ids and attention masks.
    train_ids, train_masks = encode_text(train_text, model.tokenizer, MAX_SEQ_LEN)
    validation_ids, validation_masks = encode_text(validation_text, model.tokenizer, MAX_SEQ_LEN)

    # Encode the labels against the label list.
    train_labels = encode_label(train_labels, pp.list_labels)
    validation_labels = encode_label(validation_labels, pp.list_labels)

    # Wrap the encoded data in DataLoaders.
    dt_train = Create_DataLoader_train(train_ids, train_masks, train_labels)
    dt_validation = Create_DataLoader_train(validation_ids, validation_masks, validation_labels)

    # Configure training from the training data, then train.
    model.configuration(dt_train)
    model.fit(dt_train, dt_validation)
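The page does not say what encode_label returns. Assuming it maps each label to its position in pp.list_labels (the library may instead return one-hot vectors or tensors), a minimal stand-in looks like this; encode_label_sketch is a hypothetical name, not part of Manteia.

```python
def encode_label_sketch(labels, list_labels):
    """Hypothetical stand-in for encode_label: label -> index in list_labels."""
    # Build the lookup once so each label is encoded in constant time.
    index = {label: i for i, label in enumerate(list_labels)}
    return [index[label] for label in labels]


print(encode_label_sketch(['b', 'a', 'b'], ['a', 'b']))  # [1, 0, 1]
```

A label missing from list_labels raises a KeyError here, which is one reason the label list is built from the full dataset before splitting.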
Attributes:
predict(predict_dataloader, p_type='class', mode='eval')
If early stopping is enabled, the weights saved during validation are reloaded before predicting:

    if self.early_stopping:
        # By torch: load only the model class, then restore the saved weights.
        self.load_type()
        self.load_class()
        self.model.load_state_dict(torch.load(os.path.join(self.path, 'state_dict_validation.pt')))
        # By transformers (disabled):
        # self.model.from_pretrained(self.path)
        if self.verbose:
            print('loading model early...')
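With early stopping enabled, predict relies on the checkpoint written during validation rather than on the last epoch's weights. The round trip can be sketched without PyTorch: here pickle and the hypothetical ToyModel class stand in for torch.save, torch.load and an nn.Module, so only the shape of the flow (save the best state to path/state_dict_validation.pt, later restore it) carries over.

```python
import os
import pickle
import tempfile


class ToyModel:
    """Hypothetical stand-in for a torch.nn.Module holding weights in a dict."""

    def __init__(self):
        self.weights = {"w": 0.0}

    def state_dict(self):
        return dict(self.weights)

    def load_state_dict(self, state):
        self.weights = dict(state)


path = tempfile.mkdtemp()
checkpoint = os.path.join(path, "state_dict_validation.pt")

model = ToyModel()
model.weights["w"] = 0.42           # weights at the best validation epoch
with open(checkpoint, "wb") as f:   # what a save_checkpoint call would persist
    pickle.dump(model.state_dict(), f)

model.weights["w"] = 0.99           # training continues and the weights drift
with open(checkpoint, "rb") as f:   # before predicting, restore the best weights
    model.load_state_dict(pickle.load(f))
```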
Manteia.Model.format_time(elapsed)
Takes a time in seconds and returns it as a string formatted as hh:mm:ss.
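A minimal sketch of the conversion format_time describes, under the assumption that the input is rounded to the nearest whole second; format_time_sketch is a hypothetical name, not Manteia's implementation.

```python
def format_time_sketch(elapsed):
    """Hypothetical sketch: convert a duration in seconds to 'hh:mm:ss'."""
    total = int(round(elapsed))               # round to the nearest whole second
    hours, remainder = divmod(total, 3600)    # split off whole hours
    minutes, seconds = divmod(remainder, 60)  # then whole minutes
    return f"{hours:02d}:{minutes:02d}:{seconds:02d}"


print(format_time_sketch(3725.4))  # 01:02:05
```

Rounding and zero-padding the hours are choices made here; the library's exact output may differ (for example, str(datetime.timedelta(seconds=3725)) gives 1:02:05).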