The attention mechanism addresses the information overload problem by putting more weight on the important elements of the input during training.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
print("Version: TensorFlow", tf.__version__)
Category | Type of attention | Description / example usage | Advantage | Disadvantage |
---|---|---|---|---|
Number of positions | Hard (Xu et al., 2015) | Selects the input vector with the highest probability in the attention distribution (max sampling); attends to one patch of the image at a time. Implemented with max pooling, gating, or reinforcement learning. | Less expensive at inference time | Hard to train because it is non-differentiable; requires more complicated techniques |
Number of positions | Soft / global (Luong et al., 2015) | Weights are placed "softly" on the value vectors (e.g. word vectors, image pixels, channels) and learned by gradient descent | Captures non-local, long-term dependencies | Must attend to every value on the source side for each query, which is computationally expensive |
Number of positions | Local (Xu et al., 2015; Luong et al., 2015) | Balances the trade-off between soft/global and hard attention by manually restricting the receptive field | Avoids the expensive computation; easier to train | |
Number of sequences | Distinctive (Bahdanau et al., 2014) | Key and value are not the same | | |
Number of sequences | Co-attention (Lu et al., 2016) | Fuses attention by taking the Cartesian product of the attention maps of the query and the key | Feature crosses | |
Number of sequences | Self-attention (Vaswani et al., 2017) | Soft addressing where query and key are identical | Same advantages as soft attention | |
Abstraction levels | Single-level | Attention weights are computed only on the original input sequence | | |
Abstraction levels | Multi-level (Zhao and Zhang, 2018; Yang et al., 2016) | Multiple attentions are applied sequentially at several levels, e.g. word → sentence → article | Understands the input sequence at different abstraction levels, from low to high | |
Representation | Multi-representation (Kiela et al., 2018) | Uses different word embeddings for the same input sentence | Captures different aspects of the input through multiple feature representations, e.g. lexical, syntactic, visual and genre information | |
Representation | Multi-dimensional (Lin et al., 2017) | Determines the relevance of each dimension of the input embedding vector | More effective sentence embedding representation and language understanding | |
The most commonly used attention today is soft attention, which can be understood through the analogy of soft addressing.
Addressing is the process of sending a query to find a key in a source and extracting the value associated with that key. Hard addressing matches one query to exactly one key, because the matching query and key must be identical, whereas soft addressing allows one query to match multiple keys and returns a combination of the corresponding values.
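As a rough sketch of this analogy (the keys, values, and numbers below are made up for illustration, and the code reuses the TensorFlow import from the first cell), hard addressing corresponds to a non-differentiable argmax lookup, while soft addressing is a softmax-weighted sum over all values:

keys = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])    # 3 keys of dimension 2
values = tf.constant([[10.0], [20.0], [30.0]])              # one value per key
query = tf.constant([[1.0, 0.0]])                           # a single query

scores = tf.matmul(query, keys, transpose_b=True)           # similarity of the query to each key

# Hard addressing: return only the value of the single best-matching key (argmax, non-differentiable)
hard_value = tf.gather(values, tf.argmax(scores, axis=-1))

# Soft addressing: every key contributes, weighted by a softmax over the scores (differentiable)
weights = tf.nn.softmax(scores, axis=-1)
soft_value = tf.matmul(weights, values)

print(hard_value.numpy(), soft_value.numpy())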
The calculation of attention consists of three steps. First, a similarity score $z_i$ is computed between the query and each key:

$z_i = \mathrm{similarity}(q, k_i)$

Common choices for the similarity function are listed below.
Similarity | Formula for similarity(q, k) | Remarks |
---|---|---|
Bahdanau additive style: `tf.reduce_sum(tf.tanh(query + key), axis=-1)` | $k^{T}\tanh(W^{T}q+b)$ | |
Dot product (Luong multiplicative style): `tf.matmul(query, key, transpose_b=True)` | $k^{T}q$ | Faster and more space-efficient in practice, because it can be computed in parallel with highly optimized matrix-multiplication libraries. |
Bilinear | $k^{T}Wq = k^{T}(L^{T}R)q = (Lk)^{T}(Rq)$ | Compared with the dot-product model, the bilinear model introduces asymmetry into the similarity computation. |
Cosine similarity | $\frac{k^{T}q}{\lVert k \rVert \, \lVert q \rVert}$ | |
Scaled dot-product attention | $\frac{k^{T}q}{\sqrt{d}}$ | When the input dimension $d$ is high, dot products tend to have large variance, which pushes the softmax into regions with very small gradients; scaling by $\sqrt{d}$ largely resolves this problem. |
Self-attention (multi-head attention with a single head) | $\frac{(W_{k}k)^{T}(W_{q}q)}{\sqrt{d}}$ | |
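As a small illustrative sketch (the random tensors and shapes are assumptions for demonstration, and the additive form below omits the learned weights $W$ and $b$), the scores can be computed directly with the TensorFlow one-liners from the table, keeping in mind that the additive style needs the query and key time axes broadcast against each other:

q = tf.random.normal([2, 8, 16])   # query: [batch, Tq, dim]
k = tf.random.normal([2, 4, 16])   # key:   [batch, Tv, dim]

# Luong (multiplicative) scores: one k^T q per query/key pair -> [batch, Tq, Tv]
dot_scores = tf.matmul(q, k, transpose_b=True)

# Bahdanau (additive) scores: broadcast to [batch, Tq, Tv, dim], then tanh + sum
add_scores = tf.reduce_sum(
    tf.tanh(q[:, :, tf.newaxis, :] + k[:, tf.newaxis, :, :]), axis=-1)

# Scaled dot-product scores: divide by sqrt(d) to keep the variance in check
scaled_scores = dot_scores / tf.sqrt(tf.cast(tf.shape(q)[-1], tf.float32))

print(dot_scores.shape, add_scores.shape, scaled_scores.shape)  # all (2, 8, 4)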
Second, the softmax activation amplifies the important elements, turning the scores into an importance weight $\alpha_i$ for each element:
$\alpha_i = \sigma(\vec{z})_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}$
Where $z_j$ is the similarity score between the query and the $j$-th key, and $K$ is the number of key-value pairs.
Third, a weighted summation aggregates the values with these weights into a single attention output:
$\mathrm{attention}(q, v) = v^{*} = \sum_{i} \alpha_i v_i$
where $\alpha_i$ is the attention weight of the $i$-th element and $v_i$ is the corresponding value vector.
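Putting the three steps together, here is a minimal from-scratch sketch of scaled dot-product attention (the function name, shapes, and random inputs are illustrative assumptions, not part of the original examples):

def scaled_dot_product_attention(q, k, v):
    # Step 1: similarity scores z_i; step 2: softmax weights alpha_i; step 3: weighted sum of values.
    d = tf.cast(tf.shape(k)[-1], tf.float32)
    z = tf.matmul(q, k, transpose_b=True) / tf.sqrt(d)   # [batch, Tq, Tv]
    alpha = tf.nn.softmax(z, axis=-1)                    # weights sum to 1 over the values
    return tf.matmul(alpha, v), alpha                    # v* = sum_i alpha_i * v_i

q = tf.random.normal([2, 8, 16])   # query: [batch, Tq, dim]
k = tf.random.normal([2, 4, 16])   # key:   [batch, Tv, dim]
v = tf.random.normal([2, 4, 16])   # value: [batch, Tv, dim]
out, alpha = scaled_dot_product_attention(q, k, v)
print(out.shape, alpha.shape)      # (2, 8, 16) (2, 8, 4)

In practice the same computation is available through built-in Keras layers, as in the examples below.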
# Query input of shape [batch_size, Tq, dim] and value input of shape [batch_size, Tv, dim].
query = tf.keras.Input(shape=[8, 16], name='query')
value = tf.keras.Input(shape=[4, 16], name='value')
# Bahdanau (additive) attention
Bahdanau_attention_seq = layers.AdditiveAttention(name='Bahdanau_attention_seq')([query, value])
model = keras.Model(
inputs=[query, value],
outputs=[Bahdanau_attention_seq]
)
keras.utils.plot_model(model, show_shapes=True)
# Luong (multiplicative) attention
Luong_attention_seq = layers.Attention(name='Luong_attention_seq')([query, value])
model = keras.Model(
inputs=[query, value],
outputs=[Luong_attention_seq]
)
keras.utils.plot_model(model, show_shapes=True)
# Self-attention: query, key and value are all the same tensor
self_attention_seq = layers.MultiHeadAttention(num_heads=1, key_dim=2, name='self_attention_seq')(query, query)
model = keras.Model(
inputs=[query],
outputs=[self_attention_seq]
)
keras.utils.plot_model(model, show_shapes=True)
# Multi-head attention
MultiHead_attention_seq = layers.MultiHeadAttention(num_heads=2, key_dim=2, name='MultiHead_attention_seq')(query, value)
model = keras.Model(
inputs=[query, value],
outputs=[MultiHead_attention_seq]
)
keras.utils.plot_model(model, show_shapes=True)
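As a quick shape check (using random dummy tensors, which are an assumption of this sketch rather than part of the original cells), all of these layers map a query of shape [batch, 8, 16] and a value of shape [batch, 4, 16] to an output of shape [batch, 8, 16]:

dummy_query = tf.random.normal([2, 8, 16])
dummy_value = tf.random.normal([2, 4, 16])
print(layers.Attention()([dummy_query, dummy_value]).shape)          # (2, 8, 16)
print(layers.AdditiveAttention()([dummy_query, dummy_value]).shape)  # (2, 8, 16)
print(layers.MultiHeadAttention(num_heads=2, key_dim=2)(dummy_query, dummy_value).shape)  # (2, 8, 16)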
The CBAM module is a widely used attention module in computer vision that applies two types of attention in sequence: channel attention followed by spatial attention.
Reference:
CBAM: Convolutional Block Attention Module (Woo et al., 2018), https://arxiv.org/abs/1807.06521
def cbam_module(inputs, reduction_ratio=0.5, name=""):
    # Number of input channels (assumes NHWC layout).
    channel_num = inputs.shape[-1]

    # Channel attention module: squeeze the spatial dimensions with max- and
    # average-pooling, pass both through a shared-size two-layer MLP bottleneck,
    # then gate the input channels with a sigmoid.
    maxpool_channel = tf.reduce_max(tf.reduce_max(inputs, axis=1, keepdims=True), axis=2, keepdims=True)
    avgpool_channel = tf.reduce_mean(tf.reduce_mean(inputs, axis=1, keepdims=True), axis=2, keepdims=True)
    maxpool_channel = layers.Flatten()(maxpool_channel)
    avgpool_channel = layers.Flatten()(avgpool_channel)
    mlp_1_max = layers.Dense(units=int(channel_num * reduction_ratio), name="mlp_1_maxpool", activation=tf.nn.relu)(maxpool_channel)
    mlp_2_max = layers.Dense(units=channel_num, name="mlp_2_maxpool")(mlp_1_max)
    mlp_2_max = tf.expand_dims(tf.expand_dims(mlp_2_max, 1), 1)
    mlp_1_avg = layers.Dense(units=int(channel_num * reduction_ratio), name="mlp_1_avgpool", activation=tf.nn.relu)(avgpool_channel)
    mlp_2_avg = layers.Dense(units=channel_num, name="mlp_2_avgpool")(mlp_1_avg)
    mlp_2_avg = tf.expand_dims(tf.expand_dims(mlp_2_avg, 1), 1)
    channel_attention = layers.Add(name='sum_channel_attention')([mlp_2_max, mlp_2_avg])
    channel_attention = tf.math.sigmoid(channel_attention)
    channel_refined_feature = layers.Multiply(name="channel_attention")([inputs, channel_attention])

    # Spatial attention module: pool the channel-refined feature across the
    # channel axis (as in the CBAM paper), concatenate the max- and
    # average-pooled maps, and learn a 7x7 convolutional gate.
    maxpool_spatial = tf.math.reduce_max(channel_refined_feature, axis=3, keepdims=True)
    avgpool_spatial = tf.math.reduce_mean(channel_refined_feature, axis=3, keepdims=True)
    max_avg_pool_spatial = layers.concatenate([maxpool_spatial, avgpool_spatial], axis=3, name="concat_layer")
    conv_layer = layers.Conv2D(filters=1, kernel_size=(7, 7), padding="same", activation=None)(max_avg_pool_spatial)
    spatial_attention = tf.math.sigmoid(conv_layer)
    refined_feature = layers.Multiply(name="spatial_attention")([channel_refined_feature, spatial_attention])
    return refined_feature
input_img = keras.Input(shape=[64, 64, 3], name='image_input')
refined_feature = cbam_module(input_img)
model = keras.Model(inputs = [input_img], outputs = [refined_feature])
keras.utils.plot_model(model, show_shapes=True)
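As a quick check (the random batch below is illustrative only), the CBAM block returns a refined feature map with the same shape as its input:

dummy_images = tf.random.normal([2, 64, 64, 3])
refined = cbam_module(dummy_images)
print(refined.shape)   # (2, 64, 64, 3)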