Implementation of Rotary Embeddings, from the Roformer paper, in Pytorch

Overview

Rotary Embeddings - Pytorch

A standalone library for adding rotary embeddings to transformers in Pytorch, following their success as a relative positional encoding. Specifically, it makes rotating information into any axis of a tensor easy and efficient, whether the rotations are fixed positional or learned. This library aims to give you state-of-the-art results for positional embedding at little cost.

My gut also tells me there is something more to rotations that can be exploited in artificial neural networks.

Install

$ pip install rotary-embedding-torch

Usage

import torch
from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding

# instantiate the positional embedding in your transformer and pass to all your attention layers

pos_emb = RotaryEmbedding(dim = 32)

# generate the rotations

freqs = pos_emb(torch.arange(1024), cache_key = 1024) # cache with a key that is the sequence length, so that it does not need to recompute

# mock queries and keys

q = torch.randn(1, 1024, 64) # queries - (batch, seq len, dimension of head)
k = torch.randn(1, 1024, 64) # keys

# apply the rotations to your queries and keys after the heads have been split out, but prior to the dot product and subsequent softmax (attention)

freqs = freqs[None, ...] # unsqueeze for batch dimension
q = apply_rotary_emb(freqs, q)
k = apply_rotary_emb(freqs, k)

# then do your attention with your queries (q) and keys (k)

If you follow all the steps above correctly, you should see a dramatic improvement during training.
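
To see why rotating queries and keys acts as a relative positional encoding, the following dependency-free sketch may help. It is not this library's implementation: `rotate_pairs`, the adjacent-pair layout, and the base of 10000 are assumptions made for illustration. It rotates each adjacent pair of features by an angle proportional to the position, then checks that the query-key dot product depends only on the offset between the two positions.

```python
import math

def rotate_pairs(vec, pos, base=10000.0):
    """Rotate adjacent feature pairs (x0, x1), (x2, x3), ... by pos * theta."""
    dim = len(vec)
    out = []
    for i in range(0, dim, 2):
        theta = base ** (-i / dim)  # per-pair frequency (assumed schedule)
        angle = pos * theta
        c, s = math.cos(angle), math.sin(angle)
        x, y = vec[i], vec[i + 1]
        out.extend([x * c - y * s, x * s + y * c])
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.6, 0.9]

# query 3 steps after the key, at two different absolute positions
score_a = dot(rotate_pairs(q, 5), rotate_pairs(k, 2))
score_b = dot(rotate_pairs(q, 105), rotate_pairs(k, 102))

# a different offset (4 steps) for comparison
score_c = dot(rotate_pairs(q, 6), rotate_pairs(k, 2))

print(abs(score_a - score_b) < 1e-9)  # same offset, same score up to rounding
print(abs(score_a - score_c) > 1e-3)  # different offset, different score
```

Shifting both positions by the same amount leaves the score unchanged, which is the property attention relies on when the rotations are applied before the dot product.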

Axial Rotary Embeddings

For easy use of 2D axial relative positional embeddings, e.g. in vision transformers:

import torch
from rotary_embedding_torch import apply_rotary_emb, RotaryEmbedding, broadcat

pos_emb = RotaryEmbedding(
    dim = 32,
    freqs_for = 'pixel'
)

# queries and keys for frequencies to be rotated into

q = torch.randn(1, 256, 256, 64)
k = torch.randn(1, 256, 256, 64)

# get frequencies for each axis
# -1 to 1 has been shown to be a good choice for images and audio

freqs_h = pos_emb(torch.linspace(-1, 1, steps = 256), cache_key = 256)
freqs_w = pos_emb(torch.linspace(-1, 1, steps = 256), cache_key = 256)

# concat the frequencies for each axis along the feature dimension
# broadcat function makes this easy without a bunch of expands

freqs = broadcat((freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim = -1)

# rotate in frequencies

q = apply_rotary_emb(freqs, q)
k = apply_rotary_emb(freqs, k)
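
The effect of concatenating per-axis frequencies can be sketched without any dependencies. In the sketch below, `axial_rotate`, the representation of each feature pair as one complex number, and the two frequencies are illustrative assumptions rather than the library's internals. The first half of the features is rotated by the height coordinate and the second half by the width coordinate, so the score depends only on the 2D offset between the two locations.

```python
import cmath

def rotate(pairs, position, freqs):
    """Rotate complex pair i by the angle position * freqs[i]."""
    return [z * cmath.exp(1j * position * f) for z, f in zip(pairs, freqs)]

def axial_rotate(pairs, h, w, freqs):
    """First half of the pairs encodes height, second half encodes width."""
    half = len(pairs) // 2
    return rotate(pairs[:half], h, freqs) + rotate(pairs[half:], w, freqs)

def score(q, k):
    # real dot product of the underlying (re, im) feature pairs
    return sum((a.conjugate() * b).real for a, b in zip(q, k))

freqs = [1.0, 2.5]  # illustrative frequencies, shared by both axes

# 4 complex pairs = 8 real features per location
q = [0.3 - 1.2j, 0.7 + 0.5j, -0.4 + 0.1j, 0.9 - 0.2j]
k = [1.1 + 0.4j, -0.6 + 0.9j, 0.2 - 0.8j, 0.5 + 0.3j]

# both query/key placements differ by (0.4, 0.6) in (height, width)
s1 = score(axial_rotate(q, 0.2, -0.5, freqs), axial_rotate(k, 0.6, 0.1, freqs))
s2 = score(axial_rotate(q, -0.4, 0.3, freqs), axial_rotate(k, 0.0, 0.9, freqs))

print(abs(s1 - s2) < 1e-9)  # same 2D offset, same score up to rounding
```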

Learned Rotations

For injecting learned rotations into a network. Experiments are pending.

Update: doesn't seem to do anything -_-, will keep trying...

import torch
from torch import nn
from rotary_embedding_torch import apply_learned_rotations

x = torch.randn(1, 1024, 512)

# you can only rotate in (dim // 2) values
# ex. for 512, you can only rotate in 256 values

# say you have two sets of learned rotations of 128 values each

rots1 = nn.Linear(512, 128)(x)
rots2 = nn.Linear(512, 128)(x)

# you rotate in 256 (128 x 2) at first

x = apply_learned_rotations(rots1, x, start_index = 0)

# then you start at index 256 and rotate in the last (128 x 2)

x = apply_learned_rotations(rots2, x, start_index = 256)

# you could also concat the rotations together and pass it in all at once

rots = torch.cat((rots1, rots2), dim = -1)

x = apply_learned_rotations(rots, x)
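
The comments above note that a tensor with `dim` features accepts only `dim // 2` rotation values, because each angle acts on one pair of features. A minimal dependency-free sketch of that bookkeeping follows. `rotate_with_angles` is a hypothetical helper written for illustration, not the library's `apply_learned_rotations`, and the adjacent-pair layout is an assumption.

```python
import math

def rotate_with_angles(angles, x, start_index=0):
    """Rotate adjacent pairs of x, beginning at start_index, by the given angles.

    Each angle consumes two features, so the angles touch
    x[start_index : start_index + 2 * len(angles)]; the rest is left unchanged.
    """
    out = list(x)
    for n, angle in enumerate(angles):
        i = start_index + 2 * n
        c, s = math.cos(angle), math.sin(angle)
        out[i] = x[i] * c - x[i + 1] * s
        out[i + 1] = x[i] * s + x[i + 1] * c
    return out

x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]  # 8 features, room for 4 angles
rots1 = [0.1, -0.7]                            # covers features 0..3
rots2 = [1.3, 0.4]                             # covers features 4..7

# two passes, the second starting where the first stopped
two_pass = rotate_with_angles(rots2, rotate_with_angles(rots1, x), start_index=4)

# one pass with the angles concatenated
one_pass = rotate_with_angles(rots1 + rots2, x)

same = all(abs(a - b) < 1e-12 for a, b in zip(two_pass, one_pass))
norm_before = math.sqrt(sum(v * v for v in x))
norm_after = math.sqrt(sum(v * v for v in one_pass))
print(same, abs(norm_before - norm_after) < 1e-9)
```

In this sketch the two-pass and single-pass results agree, and the vector norm is preserved, since a rotation changes direction but not length.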

Citations

@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding}, 
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

Comments
  • Custom position offset when rotating queries or keys

    This library seems to assume that queries and keys are left-aligned position-wise e.g.

    q = [p_0, p_1, p_2]
    k = [p_0, p_1, p_2, p_3, p_4]
    

    where p_i are the corresponding positions. This is enforced by always starting the sequence of positions from 0 with torch.arange(seq_len). Applications like Perceiver AR, however, require position-wise right-alignment, e.g.

    q =           [p_2, p_3, p_4]
    k = [p_0, p_1, p_2, p_3, p_4]
    

    This pull request allows specifying a start position for queries and/or keys to enable alignments other than left-alignment. For example,

    import torch
    from rotary_embedding_torch.rotary_embedding_torch import RotaryEmbedding
    
    rot = RotaryEmbedding(dim=32)
    
    q = torch.ones(1, 8, 4, 32)
    k = torch.ones(1, 8, 6, 32)
    
    q = q / torch.norm(q, dim=-1, keepdim=True)
    k = k / torch.norm(k, dim=-1, keepdim=True)
    
    q_rot = rot.rotate_queries_or_keys(q, start_pos=k.shape[2] - q.shape[2])
    k_rot = rot.rotate_queries_or_keys(k)
    
    attn = torch.einsum("b h i c, b h j c -> b h i j", q_rot, k_rot)
    print(attn[0, 0])
    

    prints the following relative position embedding

    tensor([[0.8581, 0.9571, 1.0000, 0.9571, 0.8581, 0.7670],
            [0.7670, 0.8581, 0.9571, 1.0000, 0.9571, 0.8581],
            [0.7288, 0.7670, 0.8581, 0.9571, 1.0000, 0.9571],
            [0.7361, 0.7288, 0.7670, 0.8581, 0.9571, 1.0000]])
    

    (diagonal of 1s right-aligned) whereas the default behavior

    ...
    
    q_rot = rot.rotate_queries_or_keys(q)
    k_rot = rot.rotate_queries_or_keys(k)
    
    attn = torch.einsum("b h i c, b h j c -> b h i j", q_rot, k_rot)
    print(attn[0, 0])
    

    would print

    tensor([[1.0000, 0.9571, 0.8581, 0.7670, 0.7288, 0.7361],
            [0.9571, 1.0000, 0.9571, 0.8581, 0.7670, 0.7288],
            [0.8581, 0.9571, 1.0000, 0.9571, 0.8581, 0.7670],
            [0.7670, 0.8581, 0.9571, 1.0000, 0.9571, 0.8581]])
    

    (diagonal of 1s left-aligned).

    opened by krasserm 1
  • about axial rotary embeddings

    Hi, thank you for sharing this code with us. However, I am confused by the axial rotary embeddings in the rotary_embedding_torch.py file:

    elif freqs_for == 'pixel': freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi

    Where does this formula come from? What is the max_freq parameter? Why are the freqs not 1/(10000^(2i/d))?

    Thank you again.

    opened by raindrop313 0
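
As a dependency-free illustration of the left- versus right-alignment discussed in the position-offset comment above, the sketch below rotates identical unit vectors by their positions and reports, for each query row, the key column with the highest score. It is not the library's `rotate_queries_or_keys`: the helper names, the complex-pair representation, and the frequency schedule are assumptions made for illustration.

```python
import cmath

DIM_PAIRS = 4
# assumed frequency schedule, one frequency per feature pair
FREQS = [10000.0 ** (-i / DIM_PAIRS) for i in range(DIM_PAIRS)]

def rotate(pairs, pos):
    """Rotate complex pair i by the angle pos * FREQS[i]."""
    return [z * cmath.exp(1j * pos * f) for z, f in zip(pairs, FREQS)]

def score(q, k):
    # real dot product of the underlying (re, im) feature pairs
    return sum((a.conjugate() * b).real for a, b in zip(q, k))

def attention_scores(q_len, k_len, q_start=0):
    # unit-norm vector of ones, mirroring the normalized torch.ones example
    unit = [complex(1.0, 1.0) / (2 * DIM_PAIRS) ** 0.5] * DIM_PAIRS
    queries = [rotate(unit, q_start + i) for i in range(q_len)]
    keys = [rotate(unit, j) for j in range(k_len)]
    return [[score(q, k) for k in keys] for q in queries]

def peak_columns(rows):
    """Column index of the largest score in each row."""
    return [max(range(len(r)), key=r.__getitem__) for r in rows]

left = peak_columns(attention_scores(4, 6))
right = peak_columns(attention_scores(4, 6, q_start=6 - 4))

print(left)   # [0, 1, 2, 3]
print(right)  # [2, 3, 4, 5]
```

Starting the query positions at `k_len - q_len` moves the peak of each row from the left-aligned diagonal to the right-aligned one, which matches the pattern of 1s in the two printed tensors above.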
Owner
Phil Wang
Working with Attention