Model Overview
A Keras model implementing the MixTransformer (MiT) architecture, intended as a backbone for the SegFormer architecture. This model is supported in both KerasCV and KerasHub. KerasCV will no longer be actively developed, so prefer KerasHub where possible.
References:
- SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
- Based on the TensorFlow implementation from DeepVision
Links
- MiT Quickstart Notebook
- MiT API Documentation
- KerasHub Beginner Guide
- KerasHub Model Publishing Guide
Installation
Keras and KerasHub can be installed with:
pip install -U -q keras-hub
pip install -U -q keras
JAX, TensorFlow, and PyTorch come preinstalled in Kaggle Notebooks. For instructions on installing them in another environment, see the Keras Getting Started page.
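Since several backends can be installed side by side, it may help to pick one explicitly. Keras 3 reads the `KERAS_BACKEND` environment variable when it is first imported; the choice of `"jax"` below is only an example, and the snippet is a minimal sketch rather than a required setup step.

```python
import os

# Must be set before the first `import keras` in the process.
# Accepted values are "jax", "tensorflow", and "torch".
os.environ["KERAS_BACKEND"] = "jax"

# Import Keras only after the variable is set, for example:
# import keras
```

Changing the variable after Keras has been imported has no effect in the running process, so in a notebook this belongs in the first cell.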
Presets
The following model checkpoints are provided by the Keras team. Weights have been ported from https://dl.fbaipublicfiles.com/segment_anything/. Full code examples for each are available below.
| Preset name | Parameters | Description | 
|---|---|---|
| mit_b0_ade20k_512 | 3.32M | MiT (MixTransformer) model with 8 transformer blocks, trained on the ADE20K dataset with an input resolution of 512x512 pixels. | 
| mit_b1_ade20k_512 | 13.16M | MiT (MixTransformer) model with 8 transformer blocks, trained on the ADE20K dataset with an input resolution of 512x512 pixels. | 
| mit_b2_ade20k_512 | 24.20M | MiT (MixTransformer) model with 16 transformer blocks, trained on the ADE20K dataset with an input resolution of 512x512 pixels. | 
| mit_b3_ade20k_512 | 44.08M | MiT (MixTransformer) model with 28 transformer blocks, trained on the ADE20K dataset with an input resolution of 512x512 pixels. | 
| mit_b4_ade20k_512 | 60.85M | MiT (MixTransformer) model with 41 transformer blocks, trained on the ADE20K dataset with an input resolution of 512x512 pixels. | 
| mit_b5_ade20k_640 | 81.45M | MiT (MixTransformer) model with 52 transformer blocks, trained on the ADE20K dataset with an input resolution of 640x640 pixels. | 
| mit_b0_cityscapes_1024 | 3.32M | MiT (MixTransformer) model with 8 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. | 
| mit_b1_cityscapes_1024 | 13.16M | MiT (MixTransformer) model with 8 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. | 
| mit_b2_cityscapes_1024 | 24.20M | MiT (MixTransformer) model with 16 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. | 
| mit_b3_cityscapes_1024 | 44.08M | MiT (MixTransformer) model with 28 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. | 
| mit_b4_cityscapes_1024 | 60.85M | MiT (MixTransformer) model with 41 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. | 
| mit_b5_cityscapes_1024 | 81.45M | MiT (MixTransformer) model with 52 transformer blocks, trained on the Cityscapes dataset with an input resolution of 1024x1024 pixels. | 
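The block counts in the table are totals across the four encoder stages. The sketch below shows how they add up, assuming the per-stage depths from the SegFormer reference configurations; those per-stage numbers are not stated in this card, and `MIT_STAGE_DEPTHS` and `total_blocks` are illustrative names, not part of KerasCV or KerasHub.

```python
# Per-stage transformer block depths for each MiT variant.
# Assumption: taken from the SegFormer reference configurations,
# since this card lists only the totals.
MIT_STAGE_DEPTHS = {
    "b0": [2, 2, 2, 2],
    "b1": [2, 2, 2, 2],
    "b2": [3, 4, 6, 3],
    "b3": [3, 4, 18, 3],
    "b4": [3, 8, 27, 3],
    "b5": [3, 6, 40, 3],
}


def total_blocks(variant: str) -> int:
    """Return the total number of transformer blocks for a MiT variant."""
    return sum(MIT_STAGE_DEPTHS[variant])


for variant in MIT_STAGE_DEPTHS:
    print(f"mit_{variant}: {total_blocks(variant)} blocks")
```

Under these assumed depths, B0 and B1 have the same number of blocks, so the difference in their parameter counts has to come from something other than depth, most plausibly wider layers in B1.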
Example Usage
Using the class with a backbone:
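import keras
import keras_cv
import numpy as np
# Dummy batch: one 96x96 RGB image and a matching per-pixel label map
images = np.ones(shape=(1, 96, 96, 3))
labels = np.zeros(shape=(1, 96, 96, 1))
backbone = keras_cv.models.MiTBackbone.from_preset("mit_b4_ade20k_512")
# Assumed task model: a SegFormer segmentation head on top of the MiT backbone
model = keras_cv.models.SegFormer(backbone=backbone, num_classes=1)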
# Evaluate model
model(images)
# Train model
model.compile(
    optimizer="adam",
    loss=keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=["accuracy"],
)
model.fit(images, labels, epochs=3)
Example Usage with Hugging Face URI
Using the class with a backbone:
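import keras
import keras_cv
import numpy as np
# Dummy batch: one 96x96 RGB image and a matching per-pixel label map
images = np.ones(shape=(1, 96, 96, 3))
labels = np.zeros(shape=(1, 96, 96, 1))
# The hf:// prefix loads the preset from the Hugging Face Hub
backbone = keras_cv.models.MiTBackbone.from_preset("hf://keras/mit_b4_ade20k_512")
# Assumed task model: a SegFormer segmentation head on top of the MiT backbone
model = keras_cv.models.SegFormer(backbone=backbone, num_classes=1)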
# Evaluate model
model(images)
# Train model
model.compile(
    optimizer="adam",
    loss=keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=["accuracy"],
)
model.fit(images, labels, epochs=3)
