JAX bindings to the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library

Overview

This package provides a JAX interface to (a subset of) the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library. Take a look at the FINUFFT docs for all the necessary definitions, conventions, and more information about the algorithms and their implementation. This package uses a low-level interface to directly expose the FINUFFT library to JAX's XLA backend, as well as implementing differentiation rules for the transforms.

Included features

This library is currently CPU-only, but GPU support is in the works using the cuFINUFFT library.

Type 1 and type 2 transforms are supported in 1, 2, and 3 dimensions. All of these functions support forward-mode, reverse-mode, and higher-order differentiation, as well as batching using vmap.
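
To make the differentiation claim concrete, here is a minimal NumPy sketch of a forward-mode rule for a 1D type 1 transform with respect to the points. It is not this package's implementation: it assumes the FINUFFT convention f_k = sum_j c_j exp(i k x_j) and uses a slow direct sum as a stand-in for the fast transform. The helper names mode_indices, direct_type1, and type1_jvp_points are hypothetical. The point of the sketch is that the tangent of the output is again a type 1 transform, applied to c * dx and scaled by i k, so a derivative costs roughly one extra transform.

```python
import numpy as np


def mode_indices(n_modes):
    # Integer frequencies -(n // 2) through (n - 1) // 2
    # (assumed to match FINUFFT's default mode ordering).
    return np.arange(-(n_modes // 2), (n_modes + 1) // 2)


def direct_type1(n_modes, c, x):
    # Slow O(N * M) stand-in for a type 1 transform with iflag = +1.
    k = mode_indices(n_modes)
    return np.exp(1j * np.outer(k, x)) @ c


def type1_jvp_points(n_modes, c, x, dx):
    # Directional derivative of direct_type1 with respect to x along dx:
    # d f_k = i k * sum_j (c_j dx_j) exp(i k x_j).
    k = mode_indices(n_modes)
    return 1j * k * direct_type1(n_modes, c * dx, x)


rng = np.random.default_rng(0)
M, N = 50, 16
x = 2 * np.pi * rng.uniform(size=M)
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
dx = rng.standard_normal(M)

# Compare the analytic tangent against a central finite difference.
h = 1e-6
fd = (direct_type1(N, c, x + h * dx) - direct_type1(N, c, x - h * dx)) / (2 * h)
print(np.allclose(type1_jvp_points(N, c, x, dx), fd, atol=1e-5))
```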

Installation

For now, only a source build is supported.

For building, you should only need a recent version of Python (>3.6) and FFTW. At runtime, you'll need numpy, scipy, and jax. To set up such an environment, you can use conda (but you're welcome to use whatever workflow works for you!):

conda create -n jax-finufft -c conda-forge python=3.9 numpy scipy fftw
python -m pip install "jax[cpu]"

Then you can install from source using (don't forget the --recursive flag because FINUFFT is included as a submodule):

git clone --recursive https://github.com/dfm/jax-finufft
cd jax-finufft
python -m pip install .

Usage

This library provides two high-level functions (and these should be all that you generally need to interact with): nufft1 and nufft2 (for the two "types" of transforms). If you're already familiar with the Python interface to FINUFFT, please note that the function signatures here are different!

For example, here's how you can do a 1-dimensional type 1 transform:

import numpy as np
from jax_finufft import nufft1

M = 100000
N = 200000

x = 2 * np.pi * np.random.uniform(size=M)
c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)
f = nufft1(N, c, x, eps=1e-6, iflag=1)

Note that eps and iflag are optional, and that (for good reason, I promise!) the order of the positional arguments is reversed from the finufft Python package.

The syntax for a 2- or 3-dimensional transform is:

f = nufft1((Nx, Ny), c, x, y)  # 2D
f = nufft1((Nx, Ny, Nz), c, x, y, z)  # 3D
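
If the role of each argument in the 2D call is unclear, the following slow direct sum may help. It is only a sketch under assumptions: it takes the output to be indexed as f[kx, ky] with shape (Nx, Ny), uses the +i sign convention, and the helper names mode_indices and direct_type1_2d are hypothetical rather than part of this package. See the FINUFFT docs for the authoritative definitions.

```python
import numpy as np


def mode_indices(n_modes):
    # Integer frequencies -(n // 2) through (n - 1) // 2
    # (assumed to match FINUFFT's default mode ordering).
    return np.arange(-(n_modes // 2), (n_modes + 1) // 2)


def direct_type1_2d(shape, c, x, y):
    # f[kx, ky] = sum_j c[j] * exp(i * (kx * x[j] + ky * y[j]))
    nx, ny = shape
    ex = np.exp(1j * np.outer(mode_indices(nx), x))  # shape (nx, M)
    ey = np.exp(1j * np.outer(mode_indices(ny), y))  # shape (ny, M)
    return np.einsum("am,bm,m->ab", ex, ey, c)


rng = np.random.default_rng(1)
M = 200
x = 2 * np.pi * rng.uniform(size=M)
y = 2 * np.pi * rng.uniform(size=M)
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)

f = direct_type1_2d((8, 6), c, x, y)
print(f.shape)  # (8, 6)
```

Each nonuniform point contributes one strength c[j] and one coordinate per dimension, which is why the call takes a single c but separate x and y arrays of the same length.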

The syntax for a type 2 transform is (also allowing optional iflag and eps parameters):

c = nufft2(f, x)  # 1D
c = nufft2(f, x, y)  # 2D
c = nufft2(f, x, y, z)  # 3D
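
As a rough illustration of how the two types relate, here is a direct-sum sketch in NumPy. It assumes the FINUFFT conventions (type 1 maps strengths at nonuniform points to uniform modes, type 2 maps uniform modes to values at nonuniform points, and iflag selects the sign of the exponent); direct_type1 and direct_type2 are hypothetical helpers, not functions from this package, and the iflag defaults shown are assumptions. It also shows why nufft2 needs no explicit size argument: the number of modes is read off the shape of f.

```python
import numpy as np


def mode_indices(n_modes):
    # Integer frequencies -(n // 2) through (n - 1) // 2
    # (assumed to match FINUFFT's default mode ordering).
    return np.arange(-(n_modes // 2), (n_modes + 1) // 2)


def direct_type1(n_modes, c, x, iflag=1):
    # Nonuniform points -> uniform modes:
    # f[k] = sum_j c[j] * exp(+/- i k x[j])
    sign = 1.0 if iflag >= 0 else -1.0
    return np.exp(sign * 1j * np.outer(mode_indices(n_modes), x)) @ c


def direct_type2(f, x, iflag=-1):
    # Uniform modes -> nonuniform points:
    # c[j] = sum_k f[k] * exp(+/- i k x[j])
    sign = 1.0 if iflag >= 0 else -1.0
    return np.exp(sign * 1j * np.outer(x, mode_indices(f.shape[0]))) @ f


rng = np.random.default_rng(2)
M, N = 40, 12
x = 2 * np.pi * rng.uniform(size=M)
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)
f = rng.standard_normal(N) + 1j * rng.standard_normal(N)

# With opposite signs, type 2 is the adjoint of type 1:
# <f, A c> == <A^H f, c>.
lhs = np.vdot(f, direct_type1(N, c, x, iflag=1))
rhs = np.vdot(direct_type2(f, x, iflag=-1), c)
print(np.isclose(lhs, rhs))
```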

Similar libraries

  • finufft: The "official" Python bindings to FINUFFT. A good choice if you're not already using JAX and if you don't need to differentiate through your transform.
  • mrphys/tensorflow-nufft: TensorFlow bindings for FINUFFT and cuFINUFFT.

License & attribution

This package, developed by Dan Foreman-Mackey, is licensed under the Apache License, Version 2.0, with the following copyright:

Copyright 2021 The Simons Foundation, Inc.

If you use this software, please cite the primary references listed on the FINUFFT docs.

Comments
  • batching issue

    Hi,

    I get the following error when I try to batch nufft2 in Jax.

    process_primitive(self, primitive, tracers, params)
        161     frame = self.get_frame(vals_in, dims_in)
        162     batched_primitive = self.get_primitive_batcher(primitive, frame)
    --> 163     val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
        164     if primitive.multiple_results:
        165       return map(partial(BatchTracer, self), val_out, dim_out)

    TypeError: batch() got an unexpected keyword argument 'output_shape'

    It seems like this is caused by:

    nufft2(source, iflag, eps, *points)
         57
         58     return jnp.reshape(
    ---> 59         nufft2_p.bind(source, *points, output_shape=None, iflag=iflag, eps=eps),
         60         expected_output_shape,
         61     )

    Is there something I am doing wrong?

    Thanks for your help!

    opened by samaktbo 13
  • problem batching functions that have multiple arguments

    Hi,

    I am posting this here to give a clearer description of the problem that I am having.

    When I run the following snippet, I would like A to be a 4 by 100 array whose i-th row is the output of linear_func(q, X[i, :]).

    import numpy as np
    import jax.numpy as jnp
    from jax import vmap
    from jax_finufft import nufft2
    
    rng = np.random.default_rng(seed=314)
    
    d = 10
    L_tilde = 10
    L = 100
    
    qr = rng.standard_normal((d, L_tilde+1))
    qi = rng.standard_normal((d, L_tilde+1))
    q = jnp.array(qr + 1j * qi)
    X = jnp.array(rng.uniform(low=0.0, high=1.0, size=(4, L)))
    
    def linear_func(q, x):
      v = jnp.ones(shape=(1, L_tilde))
    
      return jnp.matmul(v, nufft2(q, x, eps=1e-6, iflag=-1))
    
    batched = vmap(linear_func, in_axes=(None, 0), out_axes=0)
    
    A = batched(q, X)
    

    However, when I run the snippet, I get the error posted below. You had said last time that there could be a workaround for unbatched arguments, but I could not figure it out.

    Here is the error:

    UnfilteredStackTrace                      Traceback (most recent call last)
    <ipython-input-25-c23bc4fa7eb4> in <module>()
    ----> 1 test = batched(q, X)
    
    30 frames
    UnfilteredStackTrace: TypeError: '<' not supported between instances of 'NoneType' and 'int'
    
    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.
    
    --------------------
    
    The above exception was the direct cause of the following exception:
    
    TypeError                                 Traceback (most recent call last)
    /usr/local/lib/python3.7/dist-packages/jax_finufft/ops.py in <genexpr>(.0)
        281     else:
        282         mx = args[0].ndim - ndim - 1
    --> 283     assert all(a < mx for a in axes)
        284     assert all(a == axes[0] for a in axes[1:])
        285     return prim.bind(*args, **kwargs), axes[0]
    
    TypeError: '<' not supported between instances of 'NoneType' and 'int'
    

    I hope this gives a clearer picture than what I had last time. Thanks so much for your help!

    opened by samaktbo 5
  • Installation issue

    I am having trouble installing jax-finufft even with the instructions on the home page. The installation fails with:

    ERROR: Failed building wheel for jax-finufft
    Failed to build jax-finufft
    ERROR: Could not build wheels for jax-finufft, which is required to install pyproject.toml-based projects

    opened by samaktbo 5
  • Error when differentiating nufft1 with respect to points only

    Hi Dan,

    First, thank you for releasing this package, I was very glad to find it!

    I noted the error below when attempting to differentiate nufft1 with respect to points. I am very new to JAX so I could be mistaken, but I don't believe this is the intended behavior:

    from jax_finufft import nufft1, nufft2
    import numpy as np
    import jax.numpy as jnp
    from jax import grad
    M = 100000
    N = 200000
    
    x = 2 * np.pi * np.random.uniform(size=M)
    c = np.random.standard_normal(size=M) + 1j * np.random.standard_normal(size=M)
    
    def norm_nufft1(c,x):
        f = nufft1(N, c, x, eps=1e-6, iflag=1)
        return jnp.linalg.norm(f)
    
    def norm_nufft2(c,x):
        f = nufft2( c, x, eps=1e-6, iflag=1)
        return jnp.linalg.norm(f)
    
    grad(norm_nufft2,argnums =(1))(c,x) # Works fine
    grad(norm_nufft1,argnums =(0,))(c,x) # Works fine
    grad(norm_nufft1,argnums =(0,1))(c,x) # Works fine
    grad(norm_nufft1,argnums =(1))(c,x) # Throws error
    

    The error is below:

    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    /var/folders/dg/zzj57d7d1gs1k9l8sdfh043m0000gn/T/ipykernel_13430/3134504975.py in <module>
         22 grad(norm_nufft1,argnums =(0,))(c,x) # Works fine
         23 grad(norm_nufft1,argnums =(0,1))(c,x) # Works fine
    ---> 24 grad(norm_nufft1,argnums =(1))(c,x) # Throws error
         25 
         26 
    
        [... skipping hidden 10 frame]
    
    /var/folders/dg/zzj57d7d1gs1k9l8sdfh043m0000gn/T/ipykernel_13430/3134504975.py in norm_nufft1(c, x)
         11 
         12 def norm_nufft1(c,x):
    ---> 13     f = nufft1(N, c, x, eps=1e-6, iflag=1)
         14     return jnp.linalg.norm(f)
         15 
    
        [... skipping hidden 21 frame]
    
    ~/opt/anaconda3/envs/alphafold/lib/python3.8/site-packages/jax_finufft/ops.py in nufft1(output_shape, source, iflag, eps, *points)
         39 
         40     return jnp.reshape(
    ---> 41         nufft1_p.bind(source, *points, output_shape=output_shape, iflag=iflag, eps=eps),
         42         expected_output_shape,
         43     )
    
        [... skipping hidden 3 frame]
    
    ~/opt/anaconda3/envs/alphafold/lib/python3.8/site-packages/jax_finufft/ops.py in jvp(type_, prim, args, tangents, output_shape, iflag, eps)
        248 
        249         axis = -2 if type_ == 2 else -ndim - 1
    --> 250         output_tangent *= jnp.concatenate(jnp.broadcast_arrays(*scales), axis=axis)
        251 
        252         expand_shape = (
    
    ~/opt/anaconda3/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in concatenate(arrays, axis)
       3405   if hasattr(arrays[0], "concatenate"):
       3406     return arrays[0].concatenate(arrays[1:], axis)
    -> 3407   axis = _canonicalize_axis(axis, ndim(arrays[0]))
       3408   arrays = _promote_dtypes(*arrays)
       3409   # lax.concatenate can be slow to compile for wide concatenations, so form a
    
    ~/opt/anaconda3/envs/alphafold/lib/python3.8/site-packages/jax/_src/util.py in canonicalize_axis(axis, num_dims)
        277   axis = operator.index(axis)
        278   if not -num_dims <= axis < num_dims:
    --> 279     raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
        280   if axis < 0:
        281     axis = axis + num_dims
    
    ValueError: axis -2 is out of bounds for array of dimension 1
    
    bug 
    opened by ma-gilles 4
  • Speed up `vmap`s where none of the points are batched

    If none of the points are batched in a vmap, we should be able to get a faster computation by stacking the transforms into a single transform and then reshaping. It might be worth making the stacked axes into an explicit parameter rather than just trying to infer it, but that would take some work on the interface.

    opened by dfm 0
  • WIP: Adding support for non-batched dimensions in `vmap`

    In most cases, we'll just need to broadcast all the inputs out to the right shapes (this shouldn't be too hard), but when none of the points get mapped, we can get a bit of a speed up by stacking the transforms. This starts to implement that logic, but it's not quite ready yet.

    opened by dfm 0
  • Find a publication venue

    @lgarrison and I have been chatting about the possibility of writing a paper describing what we're doing here. Something like "Differentiable programming with NUFFTs" or "NUFFTs for machine learning applications" or something. This issue is here to remind me to look into possible venues for this.

    opened by dfm 5
  • Starting to add GPU support using cuFINUFFT

    So far I just have the CMake definitions to compile cuFINUFFT when nvcc is found, but I haven't started writing the boilerplate needed to loop it into XLA. Coming soon!

    Keeping @lgarrison in the loop.

    opened by dfm 10
  • Support Type 3?

    There aren't currently any plans to support the Type 3 transform for a few reasons:

    • I'm not totally sure of the use cases,
    • The logic will probably be somewhat more complicated than the existing implementations, and
    • cuFINUFFT doesn't seem to support Type 3.

    Do we want to work around these issues?

    opened by dfm 0
  • Add support for handling errors

    Currently, if any of the finufft methods fail with a non-zero error code, we just ignore it and keep on trucking. How should we handle this? It looks like the JAX convention is currently to set everything to NaN when an op fails, since propagating errors up from XLA can be a bit of a pain.

    opened by dfm 0
  • Add GPU support

    Via cufinufft. We'll need to figure out how to get XLA's handling of CUDA streams to play nice with cufinufft (this is way above my pay grade). Some references:

    1. A simple example of how to write an XLA compatible CUDA kernel: https://github.com/dfm/extending-jax/blob/main/lib/kernels.cc.cu
    2. The source code for how JAX wraps cuBLAS: https://github.com/google/jax/blob/main/jaxlib/cublas_kernels.cc
    3. There is a tensorflow implementation that might have some useful context: https://github.com/mrphys/tensorflow-nufft/blob/master/tensorflow_nufft/cc/kernels/nufft_kernels.cu.cc

    Perhaps @lgarrison is interested :D

    opened by dfm 0
Releases (v0.0.3)
  • v0.0.3 (Dec 10, 2021)

    What's Changed

    • Fix segfault when batching multiple transforms by @dfm in https://github.com/dfm/jax-finufft/pull/11
    • Generalize the behavior of vmap by @dfm in https://github.com/dfm/jax-finufft/pull/12

    Full Changelog: https://github.com/dfm/jax-finufft/compare/v0.0.2...v0.0.3

  • v0.0.2 (Nov 12, 2021)

    • Faster differentiation using stacked transforms
    • Better error checking for vmap

    Full Changelog: https://github.com/dfm/jax-finufft/compare/v0.0.1...v0.0.2

  • v0.0.1 (Nov 8, 2021)

Owner
Dan Foreman-Mackey