
Torch-gather

A mini library that implements several useful functions binding to PyTorch in C++.

What does gather do? Why do we need it?

When dealing with sequences of variable length, a common approach is to pad them all to the maximum length. Because sequence lengths vary, this wastes a considerable amount of computation and memory on the padded positions. gather removes the padding so that computation is performed only on the valid frames, without wasting compute resources.
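To make the overhead concrete, here is a small pure-Python sketch that counts how many elements a padded tensor stores versus a packed (gathered) one. It covers a padded batch of size (N, max(lx), V) and a pairwise padded tensor of size (N, max(lx), max(ly), V), the padded counterpart of gathersum()'s output. The helper names are hypothetical and not part of the library; the lengths are the ones printed in the gathersum() example.

```python
def padded_vs_packed(lengths, V):
    """Elements stored by a padded (N, max(lengths), V) tensor vs. a packed
    (sum(lengths), V) tensor. Hypothetical helper, for illustration only."""
    padded = len(lengths) * max(lengths) * V
    packed = sum(lengths) * V
    return padded, packed


def pairwise_padded_vs_packed(lx, ly, V):
    """Same comparison for a pairwise (N, max(lx), max(ly), V) tensor versus
    the packed layout of size (sum(lx[i] * ly[i]), V)."""
    padded = len(lx) * max(lx) * max(ly) * V
    packed = sum(x * y for x, y in zip(lx, ly)) * V
    return padded, packed


lx, ly, V = [2, 2, 1, 3, 3], [3, 1, 3, 1, 2], 3
print(padded_vs_packed(lx, V))               # (45, 33)
print(pairwise_padded_vs_packed(lx, ly, V))  # (135, 60)
```

Even for these short sequences, the pairwise padded tensor stores more than twice as many elements as the packed one.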

Install

python setup.py install

Docs

Note that all input tensors must be on a CUDA device.

  • gather.gathercat(x_padded:torch.FloatTensor, lx:torch.IntTensor)

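    Returns the concatenation of the padded tensor x_padded with its padding removed: for each sequence i, the first lx[i] rows are kept, and the results are stacked along the first dimension.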

    Input:

    x_padded (torch.float): padded tensor of size (N, L, V), where L=max(lx).

    lx (torch.int): lengths of size (N, ).

    Return:

    x_gather (torch.float): the gathered tensor without paddings of size (lx[0]+lx[1]+...+lx[N-1], V)

    Example:

    >>> import torch
    >>> from gather import gathercat
    >>> lx = torch.randint(3, 20, (5, ), dtype=torch.int32, device='cuda')
    >>> x_padded = torch.randn((5, lx.max(), 64), device='cuda')
    >>> x_padded.size(), lx.size()
    (torch.Size([5, 19, 64]), torch.Size([5]))
    >>> x_gather = gathercat(x_padded, lx)
    >>> x_gather.size()
    torch.Size([81, 64])
    # another example, with V=1
    >>> x_padded = torch.tensor([[1., 2., 3.],[1.,2.,0.]], device='cuda').unsqueeze(2)
    >>> lx = torch.tensor([3,2], dtype=torch.int32, device='cuda')
    >>> x_padded
    tensor([[[1.],
             [2.],
             [3.]],

            [[1.],
             [2.],
             [0.]]], device='cuda:0')
    >>> lx
    tensor([3, 2], device='cuda:0', dtype=torch.int32)
    >>> gathercat(x_padded, lx)
    tensor([[1.],
            [2.],
            [3.],
            [1.],
            [2.]], device='cuda:0')

    The same result can be obtained with plain PyTorch functions such as torch.cat(); gathercat(), however, is customized for this specific task and is more efficient.

  • gather.gathersum(xs:torch.FloatTensor, ys:torch.FloatTensor, lx:torch.IntTensor, ly:torch.IntTensor)

    Returns a sequence-matched broadcast sum of the paired gathered tensors xs and ys. For each pair of sequences xs_i and ys_i (the i-th sequences in xs and ys), gathersum() broadcasts them against each other so that every row of xs_i is added to every row of ys_i. In Python with torch, this step is equivalent to (xs_i.unsqueeze(1) + ys_i.unsqueeze(0)).reshape(-1, V).

    Input:

    xs (torch.float): gathered tensor of size (ST, V), where ST=sum(lx).

    ys (torch.float): gathered tensor of size (SU, V), where SU=sum(ly).

    lx (torch.int): lengths of size (N, ). lx[i] denotes the length of the i-th sequence in xs.

    ly (torch.int): lengths of size (N, ). ly[i] denotes the length of the i-th sequence in ys.

    Return:

    gathered_sum (torch.float): the gathered sequence-matched sum of size (lx[0]*ly[0] + lx[1]*ly[1] + ... + lx[N-1]*ly[N-1], V)

    Example:

    >>> import torch
    >>> from gather import gathersum
    >>> N, T, U, V = 5, 4, 4, 3
    >>> lx = torch.randint(1, T, (N, ), dtype=torch.int32, device='cuda')
    >>> ly = torch.randint(1, U, (N, ), dtype=torch.int32, device='cuda')
    >>> xs = torch.randn((lx.sum(), V), device='cuda')
    >>> ys = torch.randn((ly.sum(), V), device='cuda')
    >>> xs.size(), ys.size(), lx.size(), ly.size()
    (torch.Size([11, 3]), torch.Size([10, 3]), torch.Size([5]), torch.Size([5]))
    >>> gathered_sum = gathersum(xs, ys, lx, ly)
    >>> gathered_sum.size()
    torch.Size([20, 3])
    # let's see how the size 20 comes out
    >>> lx.tolist(), ly.tolist()
    ([2, 2, 1, 3, 3], [3, 1, 3, 1, 2])
    # still unclear? Uh, how about this?
    >>> (lx * ly).sum().item()
    20

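    This function may look unusual at first: for each i, it produces all lx[i]*ly[i] pairwise sums of the rows of xs_i and ys_i, and packs the results for all N pairs without padding. Please refer to the discussion page for a specific usage example.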
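As a sanity check for the semantics described above, the following sketch gives plain-PyTorch reference versions of both functions that also run on CPU. gathercat_ref and gathersum_ref are hypothetical names, not part of the library, and are expected to be slower than the CUDA implementations. The row ordering in gathersum_ref (the xs index varies slowest) follows the broadcast formula above; treat it as an assumption about the library's layout and verify it with torch.allclose() against gathersum() on a CUDA device if the layout matters.

```python
import torch


def gathercat_ref(x_padded: torch.Tensor, lx: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference for gathercat(): keep the first lx[i] rows of
    each sequence in x_padded (N, L, V) and concatenate to (sum(lx), V)."""
    return torch.cat([x_padded[i, :n] for i, n in enumerate(lx.tolist())], dim=0)


def gathersum_ref(xs: torch.Tensor, ys: torch.Tensor,
                  lx: torch.Tensor, ly: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference for gathersum(): for each pair i, add every row
    of xs_i (lx[i], V) to every row of ys_i (ly[i], V), flatten the result to
    (lx[i] * ly[i], V), and concatenate over all pairs."""
    V = xs.size(1)
    xs_split = torch.split(xs, lx.tolist(), dim=0)  # split packed xs per sequence
    ys_split = torch.split(ys, ly.tolist(), dim=0)  # split packed ys per sequence
    out = [
        # (lx_i, 1, V) + (1, ly_i, V) -> (lx_i, ly_i, V) -> (lx_i * ly_i, V)
        (x_i.unsqueeze(1) + y_i.unsqueeze(0)).reshape(-1, V)
        for x_i, y_i in zip(xs_split, ys_split)
    ]
    return torch.cat(out, dim=0)


# Usage on CPU, reusing the V=1 gathercat example from above.
x_padded = torch.tensor([[1., 2., 3.], [1., 2., 0.]]).unsqueeze(2)
lx = torch.tensor([3, 2], dtype=torch.int32)
print(gathercat_ref(x_padded, lx).squeeze(1).tolist())  # [1.0, 2.0, 3.0, 1.0, 2.0]

xs = torch.tensor([[1.], [2.]])           # one sequence, lx = [2]
ys = torch.tensor([[10.], [20.], [30.]])  # one sequence, ly = [3]
print(gathersum_ref(xs, ys, torch.tensor([2]), torch.tensor([3])).squeeze(1).tolist())
# [11.0, 21.0, 31.0, 12.0, 22.0, 32.0]
```

The Python loop over sequences is what makes these references slow for large N; a custom kernel can handle all pairs in one launch.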

Reference

  • The PyTorch binding is based on 1ytic/warp-rnnt.

  • For the specific usage of these functions, please refer to this discussion.

Owner: maxwellzh