Pretrained models for Jax/Haiku: MobileNet, ResNet, VGG, Xception.

Overview

Pre-trained image classification models for Jax/Haiku

Jax/Haiku Applications are deep learning models that are made available alongside pre-trained weights. These models can be used for prediction, feature extraction, and fine-tuning.

Available Models

  • MobileNetV1
  • ResNet, ResNetV2
  • VGG16, VGG19
  • Xception

Planned Releases

  • MobileNetV2, MobileNetV3
  • InceptionResNetV2, InceptionV3
  • EfficientNetV1, EfficientNetV2

Installation

Haikumodels requires Python 3.7 or later.

  1. Install the required libraries listed in "installation.txt".
  2. If Jax GPU support is desired, it must be installed separately according to your system's needs (for example, its CUDA version).
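After installation, a quick sanity check can confirm the Python version and which backend JAX picked up. This is a minimal sketch using only standard JAX calls; it does not touch haikumodels itself. If the GPU build was installed correctly, the backend should report a GPU instead of the CPU.

```python
import sys

import jax

# haikumodels needs Python 3.7 or later
assert sys.version_info >= (3, 7), "Python 3.7+ is required"

# "cpu" unless a GPU- or TPU-enabled jaxlib was installed
backend = jax.default_backend()
devices = jax.devices()
print("JAX backend:", backend)
print("Devices:", devices)
```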

Usage examples for image classification models

Classify ImageNet classes with ResNet50

import haiku as hk
import jax
import jax.numpy as jnp
from PIL import Image

import haikumodels as hm

rng = jax.random.PRNGKey(42)


def _model(images, is_training):
  net = hm.ResNet50()
  return net(images, is_training)


model = hk.transform_with_state(_model)

img_path = "elephant.jpg"
img = Image.open(img_path).resize((224, 224))

x = jnp.asarray(img, dtype=jnp.float32)
x = jnp.expand_dims(x, axis=0)
x = hm.resnet.preprocess_input(x)

params, state = model.init(rng, x, is_training=True)

preds, _ = model.apply(params, state, None, x, is_training=False)
# decode the results into a list of tuples (class, description, probability)
# (one such list for each sample in the batch)
print("Predicted:", hm.decode_predictions(preds, top=3)[0])
# Predicted:
# [('n02504013', 'Indian_elephant', 0.8784022),
# ('n01871265', 'tusker', 0.09620289),
# ('n02504458', 'African_elephant', 0.025362419)]
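`hm.decode_predictions` turns each row of class probabilities into its top-k ImageNet entries. The ranking step at its core can be sketched with `jax.lax.top_k`. This is an illustrative stand-in, not the library's implementation. The tiny `labels` list is hypothetical; the real function works with the 1000 ImageNet class ids and descriptions.

```python
import jax
import jax.numpy as jnp


def top_k_predictions(preds, labels, k=3):
  """For each sample, return (label, probability) pairs, highest first."""
  probs, idx = jax.lax.top_k(preds, k)  # both have shape (batch, k)
  return [[(labels[int(i)], float(p))
           for i, p in zip(row_idx, row_probs)]
          for row_idx, row_probs in zip(idx, probs)]


# hypothetical 4-class label set and a batch of one prediction
labels = ["cat", "dog", "elephant", "tusker"]
preds = jnp.array([[0.04, 0.02, 0.88, 0.06]])

result = top_k_predictions(preds, labels, k=2)
print(result)
```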

Extract features with VGG16

import haiku as hk
import jax
import jax.numpy as jnp
from PIL import Image

import haikumodels as hm

rng = jax.random.PRNGKey(42)

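def _model(images):
  # Haiku modules must be created inside the function that gets transformed
  net = hm.VGG16(include_top=False)
  return net(images)


model = hk.without_apply_rng(hk.transform(_model))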

img_path = "elephant.jpg"
img = Image.open(img_path).resize((224, 224))

x = jnp.asarray(img, dtype=jnp.float32)
x = jnp.expand_dims(x, axis=0)
x = hm.vgg.preprocess_input(x)

params = model.init(rng, x)

features = model.apply(params, x)
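With `include_top=False`, the model returns spatial feature maps instead of class scores. A common next step is to pool them into one vector per image and compare images by cosine similarity. The sketch below assumes NHWC feature maps, matching the fine-tuning example, which averages over axes 1 and 2. The random arrays are stand-ins for real `features`. The 7x7x512 shape is assumed here from the Keras reference VGG16 at 224x224 input.

```python
import jax
import jax.numpy as jnp


def pool_features(features):
  """Global average pooling: (N, H, W, C) feature maps -> (N, C) vectors."""
  return jnp.mean(features, axis=(1, 2))


def cosine_similarity(a, b, eps=1e-8):
  """Row-wise cosine similarity between two (N, C) arrays."""
  a = a / (jnp.linalg.norm(a, axis=-1, keepdims=True) + eps)
  b = b / (jnp.linalg.norm(b, axis=-1, keepdims=True) + eps)
  return jnp.sum(a * b, axis=-1)


# random stand-ins for two batches of VGG16 feature maps
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
features_a = jax.random.normal(key_a, (2, 7, 7, 512))
features_b = jax.random.normal(key_b, (2, 7, 7, 512))

vectors_a = pool_features(features_a)
vectors_b = pool_features(features_b)
print(vectors_a.shape)
print(cosine_similarity(vectors_a, vectors_b))
```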

Fine-tune Xception on a new set of classes

from typing import Any, Callable, NamedTuple, Optional, Sequence

import optax
import haiku as hk
import jax
import jax.numpy as jnp

import haikumodels as hm

rng = jax.random.PRNGKey(42)


class Freezable_TrainState(NamedTuple):
  trainable_params: hk.Params
  non_trainable_params: hk.Params
  state: hk.State
  opt_state: optax.OptState


# create your custom top layers and include the desired pretrained model
class ft_xception(hk.Module):

  def __init__(
      self,
      classes: int,
      classifier_activation: Callable[[jnp.ndarray],
                                      jnp.ndarray] = jax.nn.softmax,
      with_bias: bool = True,
      w_init: Optional[Callable[[Sequence[int], Any], jnp.ndarray]] = None,
      b_init: Optional[Callable[[Sequence[int], Any], jnp.ndarray]] = None,
      name: Optional[str] = None,
  ):
    super().__init__(name=name)
    self.classifier_activation = classifier_activation

    self.xception_no_top = hm.Xception(include_top=False)
    self.dense_layer = hk.Linear(
        output_size=1024,
        with_bias=with_bias,
        w_init=w_init,
        b_init=b_init,
        name="trainable_dense_layer",
    )
    self.top_layer = hk.Linear(
        output_size=classes,
        with_bias=with_bias,
        w_init=w_init,
        b_init=b_init,
        name="trainable_top_layer",
    )

  def __call__(self, inputs: jnp.ndarray, is_training: bool):
    out = self.xception_no_top(inputs, is_training)
    out = jnp.mean(out, axis=(1, 2))
    out = self.dense_layer(out)
    out = jax.nn.relu(out)
    out = self.top_layer(out)
    out = self.classifier_activation(out)
    return out


# use `transform_with_state` if the model contains batch norm;
# otherwise use `transform` followed by `without_apply_rng`
def _model(images, is_training):
  net = ft_xception(classes=200)
  return net(images, is_training)


model = hk.transform_with_state(_model)

# create your desired optimizer using Optax or alternatives
opt = optax.rmsprop(learning_rate=1e-4, momentum=0.90)


# this function initializes params and state, then splits the params into
# trainable and non_trainable sets using a keyword found in the module names
def initial_state(x_y, nonfreeze_key="trainable"):
  x, _ = x_y
  params, state = model.init(rng, x, is_training=True)

  trainable_params, non_trainable_params = hk.data_structures.partition(
      lambda m, n, p: nonfreeze_key in m, params)

  # the optimizer only updates the trainable params
  opt_state = opt.init(trainable_params)

  return Freezable_TrainState(trainable_params, non_trainable_params, state,
                              opt_state)


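# `gen_x_y` is assumed to be your own data generator that yields
# (images, one_hot_labels) batches
train_state = initial_state(next(gen_x_y))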


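# create your own custom loss function as desired; this one is plain
# categorical cross-entropy on the softmax outputs with one-hot labels
def categorical_crossentropy(y_true, y_pred, eps=1e-7):
  y_pred = jnp.clip(y_pred, eps, 1.0 - eps)
  return -jnp.mean(jnp.sum(y_true * jnp.log(y_pred), axis=-1))


def loss_function(trainable_params, non_trainable_params, state, x_y):
  x, y = x_y
  params = hk.data_structures.merge(trainable_params, non_trainable_params)
  y_, state = model.apply(params, state, None, x, is_training=True)

  cce = categorical_crossentropy(y, y_)

  return cce, state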


# to update the params and optimizer state, a train_step function is needed
@jax.jit
def train_step(train_state: Freezable_TrainState, x_y):
  trainable_params, non_trainable_params, state, opt_state = train_state
  # gradients are taken only w.r.t. the first argument (trainable_params);
  # the auxiliary output is the updated batch-norm state
  trainable_params_grads, new_state = jax.grad(
      loss_function, has_aux=True)(trainable_params, non_trainable_params,
                                   state, x_y)

  updates, new_opt_state = opt.update(trainable_params_grads, opt_state)
  new_trainable_params = optax.apply_updates(trainable_params, updates)

  train_state = Freezable_TrainState(new_trainable_params, non_trainable_params,
                                     new_state, new_opt_state)
  return train_state


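# train the model on the new data for a few epochs
# (`num_steps` is an example value; choose it to match your dataset size)
num_steps = 1000
for _ in range(num_steps):
  train_state = train_step(train_state, next(gen_x_y))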

# after training is complete, merge the trainable and
# non_trainable params to use them for prediction
trainable_params, non_trainable_params, state, _ = train_state
params = hk.data_structures.merge(trainable_params, non_trainable_params)
x, _ = next(gen_x_y)  # any batch of images preprocessed like the training data
preds, _ = model.apply(params, state, None, x, is_training=False)
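The fine-tuning example freezes the pretrained backbone in two steps. It splits the params on a keyword in the module name, then differentiates only with respect to the trainable half. The same idea is sketched below without Haiku or Optax. Plain nested dicts stand in for `hk.Params`, and a hand-written `partition` mimics the `lambda m, n, p: nonfreeze_key in m` predicate. Plain SGD replaces RMSProp. The module names and the tiny two-layer model are hypothetical.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.05


def partition(params, key):
  """Split {module_name: {param_name: array}} by whether `key` is in the module name."""
  trainable = {m: p for m, p in params.items() if key in m}
  frozen = {m: p for m, p in params.items() if key not in m}
  return trainable, frozen


def loss_fn(trainable, frozen, x, y):
  params = {**trainable, **frozen}  # merge, like hk.data_structures.merge
  h = jnp.tanh(x @ params["backbone/linear"]["w"])  # frozen feature extractor
  pred = h @ params["head/trainable_linear"]["w"]  # trainable head
  return jnp.mean((pred - y) ** 2)


@jax.jit
def sgd_step(trainable, frozen, x, y):
  # jax.grad differentiates w.r.t. the first argument only,
  # so the frozen params never receive updates
  grads = jax.grad(loss_fn)(trainable, frozen, x, y)
  return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g,
                                trainable, grads)


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "backbone/linear": {"w": jax.random.normal(k1, (4, 8))},
    "head/trainable_linear": {"w": jax.random.normal(k2, (8, 1))},
}
x = jax.random.normal(k3, (16, 4))
y = jnp.sum(x, axis=1, keepdims=True)

trainable, frozen = partition(params, "trainable")
loss_before = float(loss_fn(trainable, frozen, x, y))
for _ in range(100):
  trainable = sgd_step(trainable, frozen, x, y)
loss_after = float(loss_fn(trainable, frozen, x, y))
print(sorted(trainable), sorted(frozen))
print(loss_before, loss_after)
```

Because only `trainable` is passed as the differentiated argument and returned from the step, the backbone weights stay exactly as they were loaded. This is the property the `Freezable_TrainState` split relies on.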

You might also like...
3D ResNet Video Classification accelerated by TensorRT

Activity Recognition TensorRT Perform video classification using 3D ResNets trained on Kinetics-400 dataset and accelerated with TensorRT P.S Click on

improvement of CLIP features over the traditional resnet features on the visual question answering, image captioning, navigation and visual entailment tasks.

CLIP-ViL In our paper "How Much Can CLIP Benefit Vision-and-Language Tasks?", we show the improvement of CLIP features over the traditional resnet fea

PyTorch implementation of the R2Plus1D convolution based ResNet architecture described in the paper "A Closer Look at Spatiotemporal Convolutions for Action Recognition"

R2Plus1D-PyTorch PyTorch implementation of the R2Plus1D convolution based ResNet architecture described in the paper "A Closer Look at Spatiotemporal

PyTorch implementation of MoCo v3 for self-supervised ResNet and ViT.

MoCo v3 for Self-supervised ResNet and ViT Introduction This is a PyTorch implementation of MoCo v3 for self-supervised ResNet and ViT. The original M

Reproduces ResNet-V3 with pytorch

ResNeXt.pytorch Reproduces ResNet-V3 (Aggregated Residual Transformations for Deep Neural Networks) with pytorch. Tried on pytorch 1.6 Trains on Cifar

DeepLab resnet v2 model in pytorch

pytorch-deeplab-resnet DeepLab resnet v2 model implementation in pytorch. The architecture of deepLab-ResNet has been replicated exactly as it is from

Reproduce ResNet-v2(Identity Mappings in Deep Residual Networks) with MXNet

Reproduce ResNet-v2 using MXNet Requirements Install MXNet on a machine with CUDA GPU, and it's better also installed with cuDNN v5 Please fix the ran

NFT-Price-Prediction-CNN - Using visual feature extraction, prices of NFTs are predicted via CNN (Alexnet and Resnet) architectures.

In this project we use both Resnet and Self-attention layer for cat, dog and flower classification.

cdf_att_classification classes = {0: 'cat', 1: 'dog', 2: 'flower'} In this project we use both Resnet and Self-attention layer for cdf-Classification.

Comments
  • Expected top-1 test accuracy

    Hi

    This is a fantastic project! The released checkpoints are super helpful!

    I am wondering what top-1 test accuracy one should get using the released ResNet-50 checkpoints. I am able to reach 0.749 using my own ImageNet dataloader implemented via TensorFlow Datasets. Is that number close to your results?

    BTW, it would also be very helpful if you could release your training and dataloading code for these models!

    Thanks,

    opened by xidulu 2
  • Fitting issue

    I was trying to use a few of your pre-trained models, in particular ResNet50 and VGG16, for feature extraction, but unfortunately I didn't manage to fit them on an Nvidia Titan X with 12GB of VRAM. Which GPU did you use for training, and how much VRAM do I need to use them?

    For VGG16 the system was asking me for 4 GB more, and for ResNet50 about 20 GB more.

    Thanks.

    opened by mattiadutto 1
Owner
Alper Baris CELIK
Code for the paper “The Peril of Popular Deep Learning Uncertainty Estimation Methods”

Uncertainty Estimation Methods Code for the paper “The Peril of Popular Deep Learning Uncertainty Estimation Methods” Reference If you use this code,

EPFL Machine Learning and Optimization Laboratory 4 Apr 05, 2022
A time series processing library

Timeseria Timeseria is a time series processing library which aims at making it easy to handle time series data and to build statistical and machine l

Stefano Alberto Russo 11 Aug 08, 2022
This is an official implementation for "AS-MLP: An Axial Shifted MLP Architecture for Vision".

AS-MLP architecture for Image Classification Model Zoo Image Classification on ImageNet-1K Network Resolution Top-1 (%) Params FLOPs Throughput (image

SVIP Lab 106 Dec 12, 2022
Official implementation for "Low-light Image Enhancement via Breaking Down the Darkness"

Low-light Image Enhancement via Breaking Down the Darkness by Qiming Hu, Xiaojie Guo. 1. Dependencies Python3 PyTorch=1.0 OpenCV-Python, TensorboardX

Qiming Hu 30 Jan 01, 2023
Mahadi-Now - This Is Pakistani Just Now Login Tools

PAKISTANI JUST NOW LOGIN TOOLS Install apt update apt upgrade apt install python

MAHADI HASAN AFRIDI 19 Apr 06, 2022
All the essential resources and template code needed to understand and practice data structures and algorithms in python with few small projects to demonstrate their practical application.

Data Structures and Algorithms Python INDEX 1. Resources - Books Data Structures - Reema Thareja competitiveCoding Big-O Cheat Sheet DAA Syllabus Inte

Shushrut Kumar 129 Dec 15, 2022
Code for SyncTwin: Treatment Effect Estimation with Longitudinal Outcomes (NeurIPS 2021)

SyncTwin: Treatment Effect Estimation with Longitudinal Outcomes (NeurIPS 2021) SyncTwin is a treatment effect estimation method tailored for observat

Zhaozhi Qian 3 Nov 03, 2022
Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering

Graph ConvNets in PyTorch October 15, 2017 Xavier Bresson http://www.ntu.edu.sg/home/xbresson https://github.com/xbresson https://twitter.com/xbresson

Xavier Bresson 287 Jan 04, 2023
[ICLR 2021] Rank the Episodes: A Simple Approach for Exploration in Procedurally-Generated Environments.

[ICLR 2021] RAPID: A Simple Approach for Exploration in Reinforcement Learning This is the Tensorflow implementation of ICLR 2021 paper Rank the Episo

Daochen Zha 48 Nov 21, 2022
[TPAMI 2021] iOD: Incremental Object Detection via Meta-Learning

Incremental Object Detection via Meta-Learning To appear in an upcoming issue of the IEEE Transactions on Pattern Analysis and Machine Intelligence (T

Joseph K J 66 Jan 04, 2023
ContourletNet: A Generalized Rain Removal Architecture Using Multi-Direction Hierarchical Representation

ContourletNet: A Generalized Rain Removal Architecture Using Multi-Direction Hierarchical Representation (Accepted by BMVC'21) Abstract: Images acquir

10 Dec 08, 2022
A state of the art of new lightweight YOLO model implemented by TensorFlow 2.

CSL-YOLO: A New Lightweight Object Detection System for Edge Computing This project provides a SOTA level lightweight YOLO called "Cross-Stage Lightwe

Miles Zhang 54 Dec 21, 2022
Official code for Next Check-ins Prediction via History and Friendship on Location-Based Social Networks (MDM 2018)

MUC Next Check-ins Prediction via History and Friendship on Location-Based Social Networks (MDM 2018) Performance Details for Accuracy: | Dataset

Yijun Su 3 Oct 09, 2022
This package implements the algorithms introduced in Smucler, Sapienza, and Rotnitzky (2020) to compute optimal adjustment sets in causal graphical models.

optimaladj: A library for computing optimal adjustment sets in causal graphical models This package implements the algorithms introduced in Smucler, S

Facundo Sapienza 6 Aug 04, 2022
Adversarial Texture Optimization from RGB-D Scans (CVPR 2020).

AdversarialTexture Adversarial Texture Optimization from RGB-D Scans (CVPR 2020). Scanning Data Download Please refer to data directory for details. B

Jingwei Huang 153 Nov 28, 2022
Pytorch implementation for RelTransformer

RelTransformer Our Architecture This is a Pytorch implementation for RelTransformer The implementation for Evaluating on VG200 can be found here Requi

Vision CAIR Research Group, KAUST 21 Nov 22, 2022
Free-duolingo-plus - Duolingo account creator that uses your invite code to get you free duolingo plus

free-duolingo-plus duolingo account creator that uses your invite code to get yo

1 Jan 06, 2022
X-modaler is a versatile and high-performance codebase for cross-modal analytics.

X-modaler X-modaler is a versatile and high-performance codebase for cross-modal analytics. This codebase unifies comprehensive high-quality modules i

910 Dec 28, 2022
Capsule endoscopy detection DACON challenge

capsule_endoscopy_detection (DACON Challenge) Overview Uses models based on Yolov5, Yolor, and mmdetection (an ensemble of 11 models in total). During training, all models use Pretrained Weights from yolov5, yolo

MAILAB 11 Nov 25, 2022
TRACER: Extreme Attention Guided Salient Object Tracing Network implementation in PyTorch

TRACER: Extreme Attention Guided Salient Object Tracing Network This paper was accepted at AAAI 2022 SA poster session. Datasets All datasets are avai

Karel 118 Dec 29, 2022