Researchers from Meta AI have developed a unified neural network training method that works with images, sound and text. It uses two identical neural networks: a teacher and a student. The teacher receives the full input, while the student receives a partial input and learns to predict the teacher's internal representation of the full version of that data. Models trained this way performed better than, or comparably to, models trained with methods designed specifically for the same data type. The paper and a brief description of the method are published on the Meta AI website.
In recent years, machine learning researchers have increasingly turned to unsupervised, or self-supervised, learning. It is especially useful when a model must be trained on huge amounts of data, or on a task for which no sufficiently large labeled datasets exist, for example a language model for a rare language. One of the most striking examples of this kind of training is GPT-3, which was trained on 570 gigabytes of text. However, self-supervised methods tend to focus on a single modality (one type of data), and a method developed for one modality often cannot be applied directly to another.
A Meta AI research team led by Michael Auli has developed a learning method suitable for different types of data. The developers used a standard Transformer as the trained model for all three modalities and placed a modality-specific encoder in front of its input. Training involves two models, a teacher and a student, but they are in fact identical. The only difference is that the teacher's weights are a slightly lagging, smoothed copy of the student's weights: they are computed as an exponential moving average of the student weights.
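The teacher's "lagging, smoothed" weights can be illustrated with a minimal sketch of an exponential moving average (EMA) update. It assumes the weights are a flat list of floats and uses a hypothetical decay factor `tau`; the function name `ema_update` and the value `tau=0.999` are illustrative choices, not taken from Meta AI's code.

```python
from typing import List


def ema_update(teacher: List[float], student: List[float], tau: float = 0.999) -> List[float]:
    # Each teacher weight moves a small step (1 - tau) toward the matching
    # student weight, so the teacher trails behind the student and smooths
    # out its step-to-step fluctuations. tau is a hypothetical value here.
    if len(teacher) != len(student):
        raise ValueError("teacher and student must have the same number of weights")
    return [tau * t + (1.0 - tau) * s for t, s in zip(teacher, student)]


if __name__ == "__main__":
    # Toy run: the student's single weight jumps to 1.0 and stays there,
    # while the teacher, starting at 0.0, catches up only gradually.
    teacher = [0.0]
    student = [1.0]
    for step in range(1, 4):
        teacher = ema_update(teacher, student, tau=0.9)
        print(step, round(teacher[0], 4))  # prints 0.1, then 0.19, then 0.271
```

With `tau` close to 1, the teacher changes very slowly, which is what makes its weights a lagging, smoothed version of the student's.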
Training proceeds as follows. First, the complete data (an image, text or audio) is passed to the encoder for that data type, and the encoder's output is passed to the teacher network, which produces an internal representation of the data. The student model then receives the same data, but with part of it hidden. For images, for example, the authors masked out 60 percent of each image. The student's task is to use this partial data to predict the internal representation that the teacher model created from the full data. Because the network learns to predict an internal representation rather than the raw data itself, the learning method is not tied to a particular data type (although a data-type-specific encoder is still required), explains N+1.
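One step of this objective can be sketched in a few lines. This is a toy illustration, not Meta AI's implementation: `toy_encoder` is a hypothetical stand-in for "modality encoder + Transformer" (each output mixes its own input with the sequence mean, so representations depend on context), masked positions are replaced by a hypothetical `MASK_VALUE`, and plain squared error is chosen for simplicity rather than because it is the loss the authors used. The 60 percent mask ratio follows the image setting described above.

```python
import random
from typing import List, Sequence, Tuple

MASK_VALUE = 0.0  # hypothetical stand-in for a learned mask embedding


def toy_encoder(x: Sequence[float], w: float) -> List[float]:
    # Hypothetical stand-in for "modality encoder + Transformer": each output
    # mixes its own input with the mean of the whole sequence, so every
    # representation depends on the surrounding context.
    ctx = sum(x) / len(x)
    return [w * xi + (1.0 - w) * ctx for xi in x]


def sample_mask(n: int, ratio: float, rng: random.Random) -> List[int]:
    # Choose round(n * ratio) distinct positions to hide from the student.
    k = int(round(n * ratio))
    return sorted(rng.sample(range(n), k))


def masked_prediction_loss(
    x: Sequence[float],
    teacher_w: float,
    student_w: float,
    ratio: float = 0.6,
    seed: int = 0,
) -> Tuple[float, List[int]]:
    rng = random.Random(seed)
    masked = sample_mask(len(x), ratio, rng)
    if not masked:
        return 0.0, masked
    # Teacher sees the full input and produces the targets
    # (in real training, no gradient flows through the teacher).
    targets = toy_encoder(x, teacher_w)
    # Student sees the same input with the masked positions replaced.
    hidden = set(masked)
    x_masked = [MASK_VALUE if i in hidden else xi for i, xi in enumerate(x)]
    preds = toy_encoder(x_masked, student_w)
    # Squared error against the teacher's representations, masked positions only.
    loss = sum((preds[i] - targets[i]) ** 2 for i in masked) / len(masked)
    return loss, masked


if __name__ == "__main__":
    loss, masked = masked_prediction_loss([0.5, 1.0, 1.5, 2.0, 2.5], teacher_w=0.7, student_w=0.6)
    print("masked positions:", masked, "loss:", round(loss, 4))
```

Note that the loss compares representations, never raw pixels, words or audio samples, which is why only the encoder has to change between modalities. A full training loop would minimize this loss with respect to the student's weights and then update the teacher with an exponential moving average.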
It is important to understand that this is not a single model that works with several modalities, but a single method for training a model for one particular modality. After training, the result is a model that works with images, or with text, or with sound, but not with all three types of data at once. Last year, researchers from DeepMind introduced the Perceiver IO neural network, which can work with several types of data at once and does not use modality-specific encoders, but is trained with supervised learning. The authors of the new work acknowledge the importance of that model and consider combining the approaches of both works a promising way to build a self-supervised multimodal architecture.
The developers tested Transformer models trained with the new method on three tasks. For images, they trained two model sizes (ViT-B and ViT-L) on the ImageNet-1K dataset and measured top-1 image classification accuracy on its validation set. The speech model was trained on 960 hours of speech from the LibriSpeech dataset and evaluated by word error rate (WER). The text model was trained on the Books Corpus dataset and English Wikipedia data and evaluated on the GLUE natural language processing benchmark.
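The two metrics named above have standard definitions, sketched below. Top-1 accuracy is the fraction of examples whose highest-scoring class equals the true label; WER is the word-level edit distance (substitutions + deletions + insertions) divided by the number of words in the reference transcript. These are generic textbook versions written for illustration, not the evaluation code used by Meta AI; official benchmarks also apply their own text normalization before computing WER.

```python
from typing import List, Sequence


def top1_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    # Fraction of examples whose highest-scoring class equals the true label.
    if len(scores) != len(labels) or not labels:
        raise ValueError("need the same, non-zero number of score rows and labels")
    correct = 0
    for row, label in zip(scores, labels):
        predicted = max(range(len(row)), key=lambda c: row[c])
        correct += int(predicted == label)
    return correct / len(labels)


def word_error_rate(reference: str, hypothesis: str) -> float:
    # WER = (substitutions + deletions + insertions) / number of reference words,
    # computed as a word-level Levenshtein distance with a rolling row.
    ref: List[str] = reference.split()
    hyp: List[str] = hypothesis.split()
    if not ref:
        raise ValueError("reference transcript must contain at least one word")
    prev = list(range(len(hyp) + 1))  # distance from empty ref prefix
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            cur[j] = min(prev[j] + 1,         # deletion
                         cur[j - 1] + 1,      # insertion
                         prev[j - 1] + cost)  # substitution or match
        prev = cur
    return prev[-1] / len(ref)


if __name__ == "__main__":
    # One substitution (sat -> sit) and one deletion (the): 2 errors / 6 words.
    print(round(word_error_rate("the cat sat on the mat", "the cat sit on mat"), 4))
    # Predictions are classes 1, 0, 1; two of three match the labels.
    print(round(top1_accuracy([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]], [1, 0, 0]), 4))
```

Higher is better for top-1 accuracy, while lower is better for WER.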
NIX Solutions notes that, in testing, the new method outperformed previous approaches on image classification and speech recognition, and performed slightly worse than, but comparably to, existing approaches on text-processing tasks.