Source code for farm.modeling.adaptive_model

import copy
import json
import logging
import os
from pathlib import Path

import multiprocessing
import numpy
import torch
from torch import nn
from transformers import AutoConfig
from transformers.convert_graph_to_onnx import convert, quantize as quantize_model


from farm.data_handler.processor import Processor
from farm.modeling.language_model import LanguageModel
from farm.modeling.prediction_head import PredictionHead, pick_single_fn
from farm.utils import MLFlowLogger as MlLogger, stack
import farm.conversion.transformers as conv

# Module-level logger named after this module; used by the classes below for info/warning output.
logger = logging.getLogger(__name__)


class BaseAdaptiveModel:
    """
    Base Class for implementing AdaptiveModel with frameworks like PyTorch and ONNX.
    """

    # Registry of every subclass keyed by class name; filled in by __init_subclass__ so that the
    # generic load() below can dispatch to "AdaptiveModel" or "ONNXAdaptiveModel".
    subclasses = {}

    def __init_subclass__(cls, **kwargs):
        """
        This automatically keeps track of all available subclasses.
        Enables generic load() for all specific AdaptiveModel implementation.
        """
        super().__init_subclass__(**kwargs)
        cls.subclasses[cls.__name__] = cls

    def __init__(self, prediction_heads):
        self.prediction_heads = prediction_heads

    @classmethod
    def load(cls, **kwargs):
        """
        Load corresponding AdaptiveModel Class(AdaptiveModel/ONNXAdaptiveModel) based on the
        files in the load_dir.

        :param kwargs: arguments to pass for loading the model. Must contain "load_dir".
        :return: instance of a model
        """
        # The presence of an exported ONNX graph decides which implementation is used.
        if (Path(kwargs["load_dir"]) / "model.onnx").is_file():
            model = cls.subclasses["ONNXAdaptiveModel"].load(**kwargs)
        else:
            model = cls.subclasses["AdaptiveModel"].load(**kwargs)
        return model

    def logits_to_preds(self, logits, **kwargs):
        """
        Get predictions from all prediction heads.

        :param logits: logits, can vary in shape and type, depending on task. One entry per head.
        :type logits: object
        :param kwargs: passed unchanged to every head's logits_to_preds()
        :return: A list of all predictions from all prediction heads (one entry per head)
        """
        return [
            head.logits_to_preds(logits=logits_for_head, **kwargs)
            for head, logits_for_head in zip(self.prediction_heads, logits)
        ]

    def formatted_preds(self, logits, **kwargs):
        """
        Format predictions for inference.

        :param logits: model logits
        :type logits: torch.tensor
        :param kwargs: placeholder for passing generic parameters
        :type kwargs: object
        :return: predictions in the right format
        """
        n_heads = len(self.prediction_heads)

        if n_heads == 0:
            # just return LM output (e.g. useful for extracting embeddings at inference time).
            # NOTE: relies on the subclass providing self.language_model.
            preds_final = self.language_model.formatted_preds(logits=logits, **kwargs)

        elif n_heads == 1:
            preds_final = []
            # This try catch is to deal with the fact that sometimes we collect preds before passing it to
            # formatted_preds (see Inferencer._get_predictions_and_aggregate()) and sometimes we don't
            # (see Inferencer._get_predictions())
            try:
                preds = kwargs["preds"]
                temp = [y[0] for y in preds]
                preds_flat = [item for sublist in temp for item in sublist]
                kwargs["preds"] = preds_flat
            except KeyError:
                kwargs["preds"] = None
            head = self.prediction_heads[0]
            logits_for_head = logits[0]
            preds = head.formatted_preds(logits=logits_for_head, **kwargs)
            # TODO This is very messy - we need better definition of what the output should look like
            if isinstance(preds, list):
                preds_final += preds
            elif isinstance(preds, dict) and "predictions" in preds:
                preds_final.append(preds)

        # This case is triggered by Natural Questions
        else:
            preds_final = [list() for _ in range(n_heads)]
            preds = kwargs["preds"]
            preds_for_heads = stack(preds)
            logits_for_heads = [None] * n_heads

            samples = [s for b in kwargs["baskets"] for s in b.samples]
            kwargs["samples"] = samples

            del kwargs["preds"]

            for i, (head, preds_for_head, logits_for_head) in enumerate(
                zip(self.prediction_heads, preds_for_heads, logits_for_heads)
            ):
                preds = head.formatted_preds(logits=logits_for_head, preds=preds_for_head, **kwargs)
                preds_final[i].append(preds)

            # Look for a merge() function amongst the heads and if a single one exists, apply it to preds_final
            merge_fn = pick_single_fn(self.prediction_heads, "merge_formatted_preds")
            if merge_fn:
                preds_final = merge_fn(preds_final)

        return preds_final

    def connect_heads_with_processor(self, tasks, require_labels=True):
        """
        Populates prediction head with information coming from tasks.

        :param tasks: A dictionary where the keys are the names of the tasks and the values are the details of the
                      task (e.g. label_list, metric, tensor name)
        :param require_labels: If True, an error will be thrown when a task is not supplied with labels)
        :raises Exception: if require_labels is True and a head's task has an empty/missing label_list
        :return: None
        """
        # Drop the next sentence prediction head if it does not appear in tasks. This is triggered by the interaction
        # setting the argument BertStyleLMProcessor(next_sent_pred=False)
        if "nextsentence" not in tasks:
            idx = None
            for i, ph in enumerate(self.prediction_heads):
                if ph.task_name == "nextsentence":
                    idx = i
            if idx is not None:
                logger.info(
                    "Removing the NextSentenceHead since next_sent_pred is set to False in the BertStyleLMProcessor")
                # Delete the head that was actually found (not whatever index the loop ended on).
                del self.prediction_heads[idx]

        for head in self.prediction_heads:
            task = tasks[head.task_name]
            head.label_tensor_name = task["label_tensor_name"]
            label_list = task["label_list"]
            if not label_list and require_labels:
                raise Exception(f"The task \'{head.task_name}\' is missing a valid set of labels")
            # NOTE: for regression heads the label_list is hijacked to store the scaling factor and the mean.
            head.label_list = label_list
            head.metric = task["metric"]

    @staticmethod
    def _natural_sort_key(path):
        """Sort key that orders embedded integers numerically (prediction_head_2 before prediction_head_10)."""
        import re

        # re.split with a capture group alternates text / digit chunks, so lists compare element-wise safely.
        return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", path.name)]

    @classmethod
    def _get_prediction_head_files(cls, load_dir, strict=True):
        """
        Find the saved prediction head weight (.bin) and config files inside load_dir.

        :param load_dir: directory to search
        :param strict: if True, assert that there are as many weight files as config files
        :return: tuple (model_files, config_files), each a list of Paths sorted by head index
        """
        load_dir = Path(load_dir)
        files = os.listdir(load_dir)
        model_files = [
            load_dir / f
            for f in files
            if ".bin" in f and "prediction_head" in f
        ]
        config_files = [
            load_dir / f
            for f in files
            if "config.json" in f and "prediction_head" in f
        ]
        # sort them to get correct order in case of multiple prediction heads
        # (numerically, so that e.g. head 10 comes after head 9 and not after head 1)
        model_files.sort(key=cls._natural_sort_key)
        config_files.sort(key=cls._natural_sort_key)

        if strict:
            error_str = (
                f"There is a mismatch in number of model files ({len(model_files)}) and config files ({len(config_files)})."
                "This might be because the Language Model Prediction Head "
                "does not currently support saving and loading"
            )
            assert len(model_files) == len(config_files), error_str
        logger.info(f"Found files for loading {len(model_files)} prediction heads")

        return model_files, config_files
def loss_per_head_sum(loss_per_head, global_step=None, batch=None):
    """
    Default loss aggregation: add up the losses of all prediction heads.

    Input: loss_per_head (list of tensors), global_step (int), batch (dict)
    Output: aggregated loss (tensor)

    ``global_step`` and ``batch`` are accepted only so that this function matches the
    loss_aggregation_fn signature expected by AdaptiveModel; they are not used.
    """
    total = 0
    for head_loss in loss_per_head:
        total = total + head_loss
    return total
class AdaptiveModel(nn.Module, BaseAdaptiveModel):
    """
    PyTorch implementation containing all the modelling needed for your NLP task. Combines a language
    model and a prediction head. Allows for gradient flow back to the language model component.
    """

    def __init__(
        self,
        language_model,
        prediction_heads,
        embeds_dropout_prob,
        lm_output_types,
        device,
        loss_aggregation_fn=None,
    ):
        """
        :param language_model: Any model that turns token ids into vector representations
        :type language_model: LanguageModel
        :param prediction_heads: A list of models that take embeddings and return logits for a given task
        :type prediction_heads: list
        :param embeds_dropout_prob: The probability that a value in the embeddings returned by the
           language model will be zeroed.
        :type embeds_dropout_prob: float
        :param lm_output_types: How to extract the embeddings from the final layer of the language model. When set
                                to "per_token", one embedding will be extracted per input token. If set to
                                "per_sequence", a single embedding will be extracted to represent the full
                                input sequence. Can either be a single string, or a list of strings,
                                one for each prediction head.
        :type lm_output_types: list or str
        :param device: The device on which this model will operate. Either "cpu" or "cuda".
        :param loss_aggregation_fn: Function to aggregate the loss of multiple prediction heads.
                                    Input: loss_per_head (list of tensors), global_step (int), batch (dict)
                                    Output: aggregated loss (tensor)
                                    Default is a simple sum:
                                    `lambda loss_per_head, global_step=None, batch=None: sum(tensors)`
                                    However, you can pass more complex functions that depend on the
                                    current step (e.g. for round-robin style multitask learning) or the actual
                                    content of the batch (e.g. certain labels)
                                    Note: The loss at this stage is per sample, i.e one tensor of
                                    shape (batchsize) per prediction head.
        :type loss_aggregation_fn: function
        """
        super(AdaptiveModel, self).__init__()
        self.device = device
        self.language_model = language_model.to(device)
        self.lm_output_dims = language_model.get_output_dims()
        self.prediction_heads = nn.ModuleList([ph.to(device) for ph in prediction_heads])
        self.fit_heads_to_lm()
        # set shared weights for LM finetuning
        for head in self.prediction_heads:
            if head.model_type == "language_modelling":
                head.set_shared_weights(language_model.model.embeddings.word_embeddings.weight)
        self.dropout = nn.Dropout(embeds_dropout_prob)
        self.lm_output_types = (
            [lm_output_types] if isinstance(lm_output_types, str) else lm_output_types
        )
        self.log_params()
        # default loss aggregation function is a simple sum (without using any of the optional params)
        if not loss_aggregation_fn:
            loss_aggregation_fn = loss_per_head_sum
        self.loss_aggregation_fn = loss_aggregation_fn

    def fit_heads_to_lm(self):
        """This iterates over each prediction head and ensures that its input dimensionality matches the output
        dimensionality of the language model. If it doesn't, it is resized so it does fit"""
        for ph in self.prediction_heads:
            ph.resize_input(self.lm_output_dims)
            ph.to(self.device)

    def bypass_ph(self):
        """Replaces methods in the prediction heads with dummy functions. Used for benchmarking where we want to
        isolate the lm run time from ph run time."""

        def fake_forward(x):
            """Slices lm vector outputs of shape (batch_size, max_seq_len, dims) --> (batch_size, max_seq_len, 2)"""
            return x.narrow(2, 0, 2)

        def fake_logits_to_preds(logits, **kwargs):
            batch_size = logits.shape[0]
            return [None, None] * batch_size

        def fake_formatted_preds(**kwargs):
            return None

        for ph in self.prediction_heads:
            ph.forward = fake_forward
            ph.logits_to_preds = fake_logits_to_preds
            ph.formatted_preds = fake_formatted_preds

    def save(self, save_dir):
        """
        Saves the language model and prediction heads. This will generate a config file
        and model weights for each.

        :param save_dir: path to save to
        :type save_dir: Path
        """
        os.makedirs(save_dir, exist_ok=True)
        self.language_model.save(save_dir)
        for i, ph in enumerate(self.prediction_heads):
            ph.save(save_dir, i)
            # Need to save config and pipeline

    @classmethod
    def load(cls, load_dir, device, strict=True, lm_name=None, processor=None):
        """
        Loads an AdaptiveModel from a directory. The directory must contain:

        * language_model.bin
        * language_model_config.json
        * prediction_head_X.bin  multiple PH possible
        * prediction_head_X_config.json
        * processor_config.json config for transforming input
        * vocab.txt vocab file for language model, turning text to Wordpiece Tokens

        :param load_dir: location where adaptive model is stored
        :type load_dir: Path
        :param device: to which device we want to sent the model, either cpu or cuda
        :type device: torch.device
        :param lm_name: the name to assign to the loaded language model
        :type lm_name: str
        :param strict: whether to strictly enforce that the keys loaded from saved model match the ones in
                       the PredictionHead (see torch.nn.module.load_state_dict()).
                       Set to `False` for backwards compatibility with PHs saved with older version of FARM.
        :type strict: bool
        :param processor: populates prediction head with information coming from tasks
        :type processor: Processor
        """
        # Language Model
        if lm_name:
            language_model = LanguageModel.load(load_dir, farm_lm_name=lm_name)
        else:
            language_model = LanguageModel.load(load_dir)

        # Prediction heads
        _, ph_config_files = cls._get_prediction_head_files(load_dir)
        prediction_heads = []
        ph_output_type = []
        for config_file in ph_config_files:
            head = PredictionHead.load(config_file, strict=strict)
            prediction_heads.append(head)
            ph_output_type.append(head.ph_output_type)

        model = cls(language_model, prediction_heads, 0.1, ph_output_type, device)
        if processor:
            model.connect_heads_with_processor(processor.tasks)

        return model

    def logits_to_loss_per_head(self, logits, **kwargs):
        """
        Collect losses from each prediction head.

        :param logits: logits, can vary in shape and type, depending on task.
        :type logits: object
        :return: The per sample per prediction head loss whose first two dimensions
                 have length n_pred_heads, batch_size
        """
        all_losses = []
        for head, logits_for_one_head in zip(self.prediction_heads, logits):
            # check if PredictionHead connected to Processor
            assert hasattr(head, "label_tensor_name"), \
                (f"Label_tensor_names are missing inside the {head.task_name} Prediction Head. Did you connect the model"
                 " with the processor through either 'model.connect_heads_with_processor(processor.tasks)'"
                 " or by passing the processor to the Adaptive Model?")
            all_losses.append(head.logits_to_loss(logits=logits_for_one_head, **kwargs))
        return all_losses

    def logits_to_loss(self, logits, global_step=None, **kwargs):
        """
        Get losses from all prediction heads & reduce to single loss *per sample*.

        :param logits: logits, can vary in shape and type, depending on task
        :type logits: object
        :param global_step: number of current training step
        :type global_step: int
        :param kwargs: placeholder for passing generic parameters.
                       Note: Contains the batch (as dict of tensors), when called from Trainer.train().
        :type kwargs: object
        :return loss: torch.tensor that is the per sample loss (len: batch_size)
        """
        all_losses = self.logits_to_loss_per_head(logits, **kwargs)
        # This aggregates the loss per sample across multiple prediction heads
        # Default is sum(), but you can configure any fn that takes [Tensor, Tensor ...] and returns [Tensor]
        loss = self.loss_aggregation_fn(all_losses, global_step=global_step, batch=kwargs)
        return loss

    def prepare_labels(self, **kwargs):
        """
        Label conversion to original label space, per prediction head.

        :param kwargs: passed unchanged to every head's prepare_labels()
        :return: labels in the right format, one entry per prediction head
        """
        return [head.prepare_labels(**kwargs) for head in self.prediction_heads]

    def forward(self, **kwargs):
        """
        Push data through the whole model and returns logits. The data will propagate through the language
        model and each of the attached prediction heads.

        :param kwargs: Holds all arguments that need to be passed to the language model and prediction head(s).
        :raises ValueError: if a head's lm_output_type is not a known extraction strategy
        :return: all logits as torch.tensor or multiple tensors.
        """
        # Run forward pass of language model
        sequence_output, pooled_output = self.forward_lm(**kwargs)

        # Run forward pass of (multiple) prediction heads using the output from above
        all_logits = []
        if len(self.prediction_heads) > 0:
            for head, lm_out in zip(self.prediction_heads, self.lm_output_types):
                # Choose relevant vectors from LM as output and perform dropout
                if lm_out == "per_token":
                    output = self.dropout(sequence_output)
                elif lm_out == "per_sequence" or lm_out == "per_sequence_continuous":
                    output = self.dropout(pooled_output)
                elif lm_out == "per_token_squad":
                    # we need a per_token_squad because of variable metric computation later on...
                    output = self.dropout(sequence_output)
                else:
                    raise ValueError(
                        "Unknown extraction strategy from language model: {}".format(lm_out)
                    )

                # Do the actual forward pass of a single head
                all_logits.append(head(output))
        else:
            # just return LM output (e.g. useful for extracting embeddings at inference time)
            all_logits.append((sequence_output, pooled_output))

        return all_logits

    def forward_lm(self, **kwargs):
        """
        Forward pass for the language model.

        :param kwargs: arguments passed through to the language model's forward
        :return: tuple (sequence_output, pooled_output); pooled_output is None when an earlier
                 extraction_layer than the last one is configured
        """
        # Check if we have to extract from a special layer of the LM (default = last layer).
        # getattr with a default only swallows a missing attribute, unlike the former bare except.
        extraction_layer = getattr(self.language_model, "extraction_layer", -1)

        # Run forward pass of language model
        if extraction_layer == -1:
            sequence_output, pooled_output = self.language_model(
                **kwargs, return_dict=False, output_all_encoded_layers=False
            )
        else:
            # get output from an earlier layer
            self.language_model.enable_hidden_states_output()
            try:
                sequence_output, pooled_output, all_hidden_states = self.language_model(**kwargs, return_dict=False)
            finally:
                # always restore the LM's default output mode, even if the forward pass fails
                self.language_model.disable_hidden_states_output()
            sequence_output = all_hidden_states[extraction_layer]
            pooled_output = None  # not available in earlier layers

        return sequence_output, pooled_output

    def log_params(self):
        """
        Logs parameters to generic logger MlLogger
        """
        params = {
            "lm_type": self.language_model.__class__.__name__,
            "lm_name": self.language_model.name,
            "prediction_heads": ",".join(
                [head.__class__.__name__ for head in self.prediction_heads]
            ),
            "lm_output_types": ",".join(self.lm_output_types),
        }
        try:
            MlLogger.log_params(params)
        except Exception as e:
            # logging is best-effort and must never break model construction
            logger.warning(f"ML logging didn't work: {e}")

    def verify_vocab_size(self, vocab_size):
        """ Verifies that the model fits to the tokenizer vocabulary.
        They could diverge in case of custom vocabulary added via tokenizer.add_tokens()"""

        # resize_token_embeddings(None) returns the current embedding module without resizing it
        model_vocab_len = self.language_model.model.resize_token_embeddings(new_num_tokens=None).num_embeddings

        msg = f"Vocab size of tokenizer {vocab_size} doesn't match with model {model_vocab_len}. " \
              "If you added a custom vocabulary to the tokenizer, " \
              "make sure to supply 'n_added_tokens' to LanguageModel.load() and BertStyleLM.load()"
        assert vocab_size == model_vocab_len, msg

        for head in self.prediction_heads:
            if head.model_type == "language_modelling":
                ph_decoder_len = head.decoder.weight.shape[0]
                assert vocab_size == ph_decoder_len, msg

    def get_language(self):
        return self.language_model.language

    def convert_to_transformers(self):
        """
        Convert an adaptive model to huggingface's transformers format. Returns a list containing one model for each
        prediction head.

        :return: List of huggingface transformers models.
        """
        return conv.Converter.convert_to_transformers(self)

    @classmethod
    def convert_from_transformers(cls, model_name_or_path, device, revision=None, task_type=None, processor=None):
        """
        Load a (downstream) model from huggingface's transformers format. Use cases:
         - continue training in FARM (e.g. take a squad QA model and fine-tune on your own data)
         - compare models without switching frameworks
         - use model directly for inference

        :param model_name_or_path: local path of a saved model or name of a public one.
                                              Exemplary public names:
                                              - distilbert-base-uncased-distilled-squad
                                              - deepset/bert-large-uncased-whole-word-masking-squad2

                                              See https://huggingface.co/models for full list
        :param revision: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
        :type revision: str
        :param device: "cpu" or "cuda"
        :param task_type: One of :
                          - 'question_answering'
                          - 'text_classification'
                          - 'embeddings'
                          More tasks coming soon ...
        :param processor: populates prediction head with information coming from tasks
        :type processor: Processor
        :return: AdaptiveModel
        """
        return conv.Converter.convert_from_transformers(model_name_or_path,
                                                        revision=revision,
                                                        device=device,
                                                        task_type=task_type,
                                                        processor=processor)

    @classmethod
    def convert_to_onnx(cls, model_name, output_path, task_type, convert_to_float16=False, quantize=False, opset_version=11):
        """
        Convert a PyTorch model from transformers hub to an ONNX Model.

        :param model_name: transformers model name
        :type model_name: str
        :param output_path: output Path to write the converted to
        :type output_path: Path or str
        :param task_type: Type of task for the model. Options supported by the conversion below:
                          "embeddings", "question_answering", "ner".
        :param convert_to_float16: By default, the model use float32 precision. With half precision of float16, inference
                                should be faster on Nvidia GPUs with Tensor core like T4 or V100. On older GPUs, float32
                                might be more performant.
        :type convert_to_float16: bool
        :param quantize: convert floating point number to integers
        :type quantize: bool
        :param opset_version: ONNX opset version
        :type opset_version: int
        :return:
        """
        # Accept plain strings too; everything below relies on Path's "/" operator.
        output_path = Path(output_path)

        language_model_class = LanguageModel.get_language_model_class(model_name)
        if language_model_class not in ["Bert", "Roberta", "XLMRoberta"]:
            raise Exception("The current ONNX conversion only support 'BERT', 'RoBERTa', and 'XLMRoberta' models.")

        task_type_to_pipeline_map = {
            "question_answering": "question-answering",
            "embeddings": "feature-extraction",
            "ner": "ner"
        }

        convert(
            pipeline_name=task_type_to_pipeline_map[task_type],
            framework="pt",
            model=model_name,
            output=output_path / "model.onnx",
            opset=opset_version,
            # compare by value: `is` on string literals is unreliable
            use_external_format=language_model_class == "XLMRoberta",
        )

        # save processor & model config files that are needed when loading the model with the FARM Inferencer
        processor = Processor.convert_from_transformers(
            tokenizer_name_or_path=model_name,
            task_type=task_type,
            max_seq_len=256,
            doc_stride=128,
            use_fast=True
        )
        processor.save(output_path)
        model = AdaptiveModel.convert_from_transformers(model_name, device="cpu", task_type=task_type)
        model.save(output_path)
        os.remove(output_path / "language_model.bin")  # remove the actual PyTorch model(only configs are required)

        onnx_model_config = {
            "task_type": task_type,
            "onnx_opset_version": opset_version,
            "language_model_class": language_model_class,
            "language": model.language_model.language
        }
        with open(output_path / "onnx_model_config.json", "w") as f:
            json.dump(onnx_model_config, f)

        if convert_to_float16:
            from onnxruntime_tools import optimizer
            config = AutoConfig.from_pretrained(model_name)
            optimized_model = optimizer.optimize_model(
                input=str(output_path / "model.onnx"),
                model_type='bert',
                # the optimizer expects the number of attention heads, not the number of layers
                num_heads=config.num_attention_heads,
                hidden_size=config.hidden_size
            )
            optimized_model.convert_model_float32_to_float16()
            # overwrite the exported model in output_path (previously written to the current working directory)
            optimized_model.save_model_to_file(str(output_path / "model.onnx"))

        if quantize:
            quantize_model(output_path / "model.onnx")
class ONNXAdaptiveModel(BaseAdaptiveModel):
    """
    Implementation of ONNX Runtime for Inference of ONNX Models.

    Existing PyTorch based FARM AdaptiveModel can be converted to ONNX format using AdaptiveModel.convert_to_onnx().
    The conversion is currently only implemented for Question Answering Models.

    For inference, this class is compatible with the FARM Inferencer.
    """

    def __init__(self, onnx_session, language_model_class, language, prediction_heads, device):
        """
        :param onnx_session: an onnxruntime.InferenceSession over the exported model
        :param language_model_class: "Bert", "Roberta" or "XLMRoberta"; selects which inputs forward() feeds
        :param language: language the model was trained for
        :param prediction_heads: PredictionHead instances (used only for their non-network methods)
        :param device: "cpu" or "cuda"
        :raises Exception: if "cuda" is requested but onnxruntime has no GPU support installed
        """
        if str(device) == "cuda":
            # onnxruntime is an optional dependency that is not imported at module level
            import onnxruntime

            if onnxruntime.get_device() != "GPU":
                raise Exception(f"Device {device} not available for Inference. For CPU, run pip install onnxruntime and "
                                f"for GPU run pip install onnxruntime-gpu")
        self.onnx_session = onnx_session
        self.language_model_class = language_model_class
        self.language = language
        self.prediction_heads = prediction_heads
        self.device = device

    @classmethod
    def load(cls, load_dir, device, **kwargs):
        """
        Load an ONNX model exported with AdaptiveModel.convert_to_onnx().

        :param load_dir: directory containing model.onnx, onnx_model_config.json and prediction head configs
        :param device: "cpu" or "cuda"
        :param kwargs: ignored; accepted for compatibility with BaseAdaptiveModel.load()
        :return: ONNXAdaptiveModel
        """
        load_dir = Path(load_dir)
        import onnxruntime
        sess_options = onnxruntime.SessionOptions()
        # Set graph optimization level to ORT_ENABLE_EXTENDED to enable bert optimization.
        sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
        # Use OpenMP optimizations. Only useful for CPU, has little impact for GPUs.
        sess_options.intra_op_num_threads = multiprocessing.cpu_count()
        onnx_session = onnxruntime.InferenceSession(str(load_dir / "model.onnx"), sess_options)

        # Prediction heads
        _, ph_config_files = cls._get_prediction_head_files(load_dir, strict=False)
        prediction_heads = []
        for config_file in ph_config_files:
            # ONNX Model doesn't need have a separate neural network for PredictionHead. It only uses the
            # instance methods of PredictionHead class, so, we load with the load_weights param as False.
            head = PredictionHead.load(config_file, load_weights=False)
            prediction_heads.append(head)

        with open(load_dir / "onnx_model_config.json") as f:
            model_config = json.load(f)
            language_model_class = model_config["language_model_class"]
            language = model_config["language"]

        return cls(onnx_session, language_model_class, language, prediction_heads, device)

    @staticmethod
    def _to_onnx_input(tensor):
        """Move a torch tensor to CPU and return it as a C-contiguous numpy array for onnxruntime."""
        return numpy.ascontiguousarray(tensor.cpu().numpy())

    def forward(self, **kwargs):
        """
        Perform forward pass on the model and return the logits.

        :param kwargs: all arguments that needs to be passed on to the model
                       (input_ids, padding_mask and, for Bert, segment_ids)
        :raises ValueError: if language_model_class has no known ONNX input layout
        :return: all logits as torch.tensor or multiple tensors.
        """
        with torch.no_grad():
            if self.language_model_class == "Bert":
                input_to_onnx = {
                    'input_ids': self._to_onnx_input(kwargs['input_ids']),
                    'attention_mask': self._to_onnx_input(kwargs['padding_mask']),
                    'token_type_ids': self._to_onnx_input(kwargs['segment_ids']),
                }
            elif self.language_model_class in ["Roberta", "XLMRoberta"]:
                input_to_onnx = {
                    'input_ids': self._to_onnx_input(kwargs['input_ids']),
                    'attention_mask': self._to_onnx_input(kwargs['padding_mask'])
                }
            else:
                raise ValueError(
                    f"ONNX inference is not implemented for language model class '{self.language_model_class}'"
                )
            res = self.onnx_session.run(None, input_to_onnx)
            # Stack the session outputs on a new leading axis, then move that axis last.
            # transpose(1, 2, 0) needs a 3-D array, i.e. every output must be 2-D
            # (presumably (batch, seq_len) start/end logits for QA — TODO confirm for other tasks).
            res = numpy.stack(res).transpose(1, 2, 0)
            logits = [torch.Tensor(res).to(self.device)]

        return logits

    def eval(self):
        """
        Stub to make ONNXAdaptiveModel compatible with the PyTorch AdaptiveModel.
        """
        return True

    def get_language(self):
        """
        Get the language(s) the model was trained for.
        :return: str
        """
        return self.language
class ONNXWrapper(AdaptiveModel):
    """
    Adapter used when exporting a PyTorch AdaptiveModel to ONNX.

    torch.onnx.export (as of torch v1.4.0) can only hand positional arguments to the model's
    forward pass, whereas AdaptiveModel.forward expects keyword arguments. This subclass accepts
    the positional form and re-dispatches it as keywords.
    """

    @classmethod
    def load_from_adaptive_model(cls, adaptive_model):
        """Return a deep copy of ``adaptive_model`` whose class has been switched to ONNXWrapper."""
        wrapped = copy.deepcopy(adaptive_model)
        wrapped.__class__ = ONNXWrapper
        return wrapped

    def forward(self, *batch):
        """Map positional (input_ids, padding_mask, segment_ids) onto AdaptiveModel.forward keywords."""
        input_ids = batch[0]
        padding_mask = batch[1]
        segment_ids = batch[2]
        return super().forward(
            input_ids=input_ids,
            padding_mask=padding_mask,
            segment_ids=segment_ids,
        )