ProteinNPT (Fig. 1) is a semi-supervised conditional pseudo-generative model that learns a joint representation of protein sequences and associated property labels. The model takes as input both the primary structure representation of the proteins and the corresponding labels for the property of interest. Let $\mathcal{D} = (X, Y)$ be the full training dataset, where $X \in \mathcal{V}^{N \times L}$ are protein sequences over the amino-acid vocabulary $\mathcal{V}$ (with $N$ the total number of labeled protein sequences and $L$ the sequence length), and $Y \in \mathbb{R}^{N \times M}$ the corresponding property labels (where $M$ is the number of distinct such labels, including true targets and auxiliary labels, as discussed in § 3.2). Depending on whether we are at training or inference time, we sample a batch of points and mask different parts of this combined input as per the procedure described later in this section. We separately embed protein sequences and labels, then concatenate the resulting sequence and label embeddings (each of dimension $d$) into a single tensor $Z \in \mathbb{R}^{N \times (L+M) \times d}$, which we feed into several ProteinNPT layers. A ProteinNPT layer (Fig. 1, right) learns a joint representation of protein sequences and labels by successively applying self-attention between residues and labels of a given sequence (row-attention), self-attention across sequences in the input batch at a given position (column-attention), and a feedforward layer. Each of these transforms is preceded by a LayerNorm operator, and we add residual connections to the output of each step. For the multi-head row-attention sub-layer, we linearly project the embeddings $Z_n$ of each labeled sequence into queries, keys and values for each attention head $h$ via the linear maps $W_h^Q$, $W_h^K$ and $W_h^V$ respectively. Mathematically, we have:
$$\mathrm{RowAttn}(Z_n) = \mathrm{concat}\big(H_1^n, \ldots, H_k^n\big)\, W^O, \qquad H_h^n = A_h\, V_h^n$$
where the concatenation is performed row-wise, $W^O$ mixes outputs from the $k$ different heads, and we use tied row-attention as defined in Rao et al. [2021], as the attention maps ought to be similar across labeled instances from the same protein family:
$$A_h = \mathrm{softmax}\!\left(\frac{1}{\sqrt{N d_h}} \sum_{n=1}^{N} Q_h^n \big(K_h^n\big)^{\top}\right)$$
where $Q_h^n = Z_n W_h^Q$, $K_h^n = Z_n W_h^K$, $V_h^n = Z_n W_h^V$, and $d_h$ is the per-head embedding dimension.
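To make tied row-attention concrete, here is a minimal NumPy sketch (function names and weight shapes are illustrative, not taken from the ProteinNPT codebase): one attention map per head is shared across all $N$ rows by summing the attention logits over the batch dimension with square-root normalization.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def tied_row_attention(Z, Wq, Wk, Wv, Wo):
    """Tied row-attention over a batch Z of shape (N, L+M, d).

    Attention logits are summed over the N rows and normalized by
    sqrt(N * d_h), so all rows share one attention map per head
    (square-root normalization, as in Rao et al. [2021]).
    Wq, Wk, Wv: (k, d, d_h); Wo: (k * d_h, d). Hypothetical shapes.
    """
    N = Z.shape[0]
    d_h = Wq.shape[-1]
    heads = []
    for h in range(Wq.shape[0]):
        Q = Z @ Wq[h]                      # (N, L+M, d_h)
        K = Z @ Wk[h]
        V = Z @ Wv[h]
        # Tie attention across rows: sum the logits over the batch dimension.
        logits = np.einsum("nid,njd->ij", Q, K) / np.sqrt(N * d_h)
        A = softmax(logits, axis=-1)       # (L+M, L+M), shared by all rows
        heads.append(np.einsum("ij,njd->nid", A, V))
    # Row-wise concatenation of heads, then mix with the output projection.
    return np.concatenate(heads, axis=-1) @ Wo  # (N, L+M, d)
```

The `einsum` over `n` is what ties the map: every labeled sequence in the batch contributes to, and then reuses, the same per-head attention pattern.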
We then apply column-attention as follows:
$$\mathrm{ColAttn}(Z_{:,j}) = \mathrm{concat}\big(H_1^j, \ldots, H_k^j\big)\, W^O, \qquad H_h^j = \mathrm{Attn}\big(Z_{:,j} W_h^Q,\; Z_{:,j} W_h^K,\; Z_{:,j} W_h^V\big)$$
where the concatenation is performed column-wise, $Z_{:,j} \in \mathbb{R}^{N \times d}$ denotes the embeddings at position $j$ across the batch, $W^O$ mixes outputs from the different heads, and we use the standard self-attention operator $\mathrm{Attn}(Q, K, V) = \mathrm{softmax}\big(Q K^{\top} / \sqrt{d_h}\big)\, V$. Lastly, the feedforward sub-layer applies a row-wise feed-forward network:
$$\mathrm{FFN}(z) = W_2\, \mathrm{ReLU}(W_1 z + b_1) + b_2$$
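The column-attention and feedforward sub-layers can be sketched in the same style (again a NumPy illustration with hypothetical helper names; the LayerNorm and residual wrapping of each sub-layer, $Z + \mathrm{SubLayer}(\mathrm{LN}(Z))$, is omitted for brevity):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attn(Q, K, V):
    """Standard scaled dot-product attention."""
    return softmax(Q @ K.swapaxes(-1, -2) / np.sqrt(Q.shape[-1]), axis=-1) @ V

def column_attention(Z, Wq, Wk, Wv, Wo):
    """Self-attention across the N sequences at each of the L+M positions.

    Z: (N, L+M, d). Transposing puts positions first, so each position
    attends over the batch independently. Wq/Wk/Wv: (k, d, d_h); Wo: (k*d_h, d).
    """
    Zt = Z.swapaxes(0, 1)  # (L+M, N, d)
    heads = [attn(Zt @ Wq[h], Zt @ Wk[h], Zt @ Wv[h])
             for h in range(Wq.shape[0])]
    return (np.concatenate(heads, axis=-1) @ Wo).swapaxes(0, 1)  # (N, L+M, d)

def feedforward(Z, W1, b1, W2, b2):
    """Row-wise (position-wise) feed-forward network with ReLU."""
    return np.maximum(Z @ W1 + b1, 0.0) @ W2 + b2
```

Note the symmetry with row-attention: the same attention operator is applied, only the axis it mixes over changes from positions within a sequence to sequences within the batch.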
In the final stage, the learned embeddings from the last layer are used to predict both the masked tokens and the masked targets: the embeddings of masked targets are fed into an L2-penalized linear projection to predict the masked target values, while the embeddings of masked tokens are linearly projected and passed through a softmax activation to predict the corresponding original tokens.
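These two prediction heads admit a very small sketch (hypothetical names; the L2 penalty on the target head's weights enters the training loss, not this forward pass):

```python
import numpy as np

def predict_masked(Z_targets, Z_tokens, W_target, W_token):
    """Hypothetical prediction heads applied to final-layer embeddings.

    Z_targets: (n_masked_targets, d) embeddings at masked target slots;
    a linear projection (trained with an L2 weight penalty) yields scalar
    target predictions.
    Z_tokens: (n_masked_tokens, d) embeddings at masked token slots;
    a linear projection followed by softmax yields a distribution over
    the amino-acid vocabulary for each masked token.
    """
    target_preds = Z_targets @ W_target             # (n_masked_targets,)
    logits = Z_tokens @ W_token                     # (n_masked_tokens, vocab)
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    token_probs = e / e.sum(axis=-1, keepdims=True)
    return target_preds, token_probs
```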
Figure 1: (Left) The model takes as input the primary structure of a batch of proteins of length $L$, along with the corresponding labels and, optionally, auxiliary labels (for simplicity we consider a single label here). Each input is embedded separately, then all resulting embeddings are concatenated into a single tensor. Several ProteinNPT layers are subsequently applied to learn a representation of the entire batch, which is ultimately used to predict both masked tokens and targets (depicted by question marks). (Right) A ProteinNPT layer alternates between tied row-attention and column-attention to learn rich embeddings of the labeled batch.