Sajal Sharma

Building an Image Classifier Really Fast Using Fastai

Machine Learning, Computer Vision, fastai · 3 min read



I recently started the fastai course to build up my practical deep learning skills. To better retain what I learn, I'm going to write a series of posts/notebooks implementing my own models based on the course content. This notebook is based on what I learned in the first week of the course.

In this notebook we'll build an image classifier using fastai, a deep learning library built on top of PyTorch that provides both high-level and low-level components to quickly build state-of-the-art models for common deep learning domains.

We'll build a model that can classify images of fruit into a binary category: rotten or not. You can imagine such a model being used inside refrigerators to detect if produce kept inside it has gone bad.

When I started learning ML in 2016, building such models was a non-trivial task. Libraries for building deep neural networks were still in their infancy (PyTorch was introduced in late 2016), and building accurate image classification models required a certain degree of specialized knowledge. All that has changed and, as you'll notice in the notebook, we can now build an image classifier in just a few lines of code.

Let's get started!

import os
# !pip install -Uqq fastai duckduckgo_search

We'll need the duckduckgo_search package to quickly search for and download images of rotten and fresh fruit to feed to our model. An advantage of this library over the alternatives is that you don't need to set up an API key for basic usage.

Downloading images of rotten and fresh fruit

from duckduckgo_search import ddg_images
from fastcore.all import *

def search_images(term, max_images=40):
    """Searches for and returns images for a given term"""
    print(f"Searching for '{term}'")
    return L(ddg_images(term, max_results=max_images)).itemgot('image')

urls = search_images('rotten fruit', max_images=1)

Searching for 'rotten fruit'
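To make the last line of search_images a little less magical: the itemgot('image') call implies that each search result is a dictionary with an 'image' key, and the call pulls that value out of every result. Here is a plain-Python sketch of the same idea. The result dictionaries, the URLs, and the image_urls helper are made up for illustration and are not real search output.

```python
# Hypothetical search results: the real ones come from ddg_images.
# The URLs and the 'title' key are placeholders for illustration only.
fake_results = [
    {"title": "Rotten apple", "image": "https://example.com/rotten-apple.jpg"},
    {"title": "Mouldy orange", "image": "https://example.com/mouldy-orange.jpg"},
]

def image_urls(results):
    """Plain-Python counterpart of L(results).itemgot('image')."""
    return [result["image"] for result in results]

urls = image_urls(fake_results)
print(urls)
```

The fastcore version returns an L rather than a plain list, but it can be indexed the same way, which is why urls[0] works in the next step.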

Let's download this image and open it.

from fastdownload import download_url
dest = 'rotten.jpg'
download_url(urls[0], dest, show_progress=False)

from fastai.vision.all import *
im = Image.open(dest)
im.to_thumb(256,256)


Doing something similar for fresh fruit.

download_url(search_images('fresh fruit', max_images=1)[0], 'fresh.jpg', show_progress=False)
Image.open('fresh.jpg').to_thumb(256,256)

Searching for 'fresh fruit'


Now that we know that DuckDuckGo image search is working, we can download images for both rotten and fresh fruit and store them in their respective directories. We use time.sleep to avoid spamming the search API.

searches = 'rotten', 'fresh'
path = Path('rotten_or_not')  # assumed directory name; its definition is not shown in the original

from time import sleep

for o in searches:
    dest = (path/o)
    dest.mkdir(exist_ok=True, parents=True)
    download_images(dest, urls=search_images(f'{o} fruit'))
    sleep(5)  # Pause between searches to avoid over-loading server
    download_images(dest, urls=search_images(f'{o} apple'))
    sleep(5)  # Pause between searches to avoid over-loading server
    download_images(dest, urls=search_images(f'{o} banana'))
    sleep(5)  # Pause between searches to avoid over-loading server
    download_images(dest, urls=search_images(f'{o} vegetables'))
    resize_images(path/o, max_size=400, dest=path/o)
Searching for 'rotten fruit'
Searching for 'rotten apple'
Searching for 'rotten banana'
Searching for 'rotten vegetables'
Searching for 'fresh fruit'
Searching for 'fresh apple'
Searching for 'fresh banana'
Searching for 'fresh vegetables'
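The loop above repeats the same download-then-sleep pattern for every search term. One possible way to tidy it up is sketched below. The helper names (build_queries, throttled_map) and the fetch callable are my own and not part of fastai. Here fetch stands in for the download_images(dest, urls=search_images(query)) step so that the sketch runs without network access.

```python
from time import sleep

def build_queries(category, subjects=("fruit", "apple", "banana", "vegetables")):
    """Return the search terms used for one category, e.g. 'rotten apple'."""
    return [f"{category} {subject}" for subject in subjects]

def throttled_map(fetch, queries, pause=5.0):
    """Call fetch(query) for each query, sleeping `pause` seconds between calls."""
    results = []
    for i, query in enumerate(queries):
        if i > 0:
            sleep(pause)  # pause between searches to avoid overloading the server
        results.append(fetch(query))
    return results

# Stand-in for the real search/download step: just upper-case the query.
seen = throttled_map(lambda q: q.upper(), build_queries("rotten"), pause=0)
print(seen)
```

Keeping the subjects in one tuple makes it easy to add another fruit later without copying the sleep call again.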

Training our model

We have our images, and the next step is to train a model. Again, it blows my mind how simple this is with fastai. I'll briefly explain what the code blocks below are doing.

First, we check whether all image files can be opened correctly using verify_images, a utility from the fastai vision library. If a file can't be opened, we unlink (delete) it so that it is not used in model training.

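# validate images (reconstructed from the description above; the original code is missing here)
verify_images(get_image_files(path)).map(Path.unlink)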
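To give a feel for what this validation step does, here is a deliberately simplified stand-in that uses only the standard library. It is not how verify_images works internally (fastai actually tries to open each file as an image). This sketch only checks for JPEG or PNG magic bytes, and the function names are hypothetical.

```python
from pathlib import Path
import tempfile

# File signatures ("magic bytes") for the two formats this sketch recognises.
SIGNATURES = (b"\xff\xd8\xff", b"\x89PNG\r\n\x1a\n")

def looks_like_image(path):
    """Return True if the file starts with a JPEG or PNG signature."""
    with open(path, "rb") as f:
        header = f.read(8)
    return header.startswith(SIGNATURES)

def remove_bad_files(folder):
    """Delete files that fail the check and return their names, sorted."""
    failed = [p for p in Path(folder).iterdir()
              if p.is_file() and not looks_like_image(p)]
    for p in failed:
        p.unlink()
    return sorted(p.name for p in failed)

# Tiny demo in a temporary folder: one plausible PNG header, one broken download.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "ok.png").write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 16)
    (Path(tmp) / "broken.jpg").write_bytes(b"<html>404</html>")
    removed = remove_bad_files(tmp)
    remaining = sorted(p.name for p in Path(tmp).iterdir())
print(removed, remaining)
```

Image searches often return a few URLs that serve an error page instead of a picture, which is why a cleanup step like this is worth running before training.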

Next, we'll use another building block from the fastai library, the DataBlock class, which we can use to represent our training data, the labels, data splitting criteria, and any data transformations.

blocks=(ImageBlock, CategoryBlock) is used to specify what kind of data is in the DataBlock. We have images, and categories - hence a tuple of ImageBlock and CategoryBlock classes.

get_items takes the function get_image_files as its parameter. get_image_files is used to find the paths of our input images.

splitter=RandomSplitter(valid_pct=0.2, seed=42) specifies that we want to randomly split our input data into training and validation sets, using 20% of the data for validation.

get_y=parent_label specifies that the label for an image file is the name of its parent (the directory that the file belongs to).

item_tfms=[Resize(192, method='squish')] specifies the transformation performed on each file. Here we are resizing each image to 192x192 pixels by squishing it. Another option could be to crop the image.
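The difference between squishing and cropping comes down to simple arithmetic. The sketch below is not fastai's implementation. It only computes, for a hypothetical 400x300 photo, how each axis is scaled when squishing and which region survives a centered square crop. Both function names are my own.

```python
def squish_scale(width, height, size):
    """Scale factors applied to each axis when squishing to size x size."""
    return size / width, size / height

def center_crop_box(width, height):
    """Largest centered square (left, top, right, bottom) in the original image."""
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return left, top, left + side, top + side

# A hypothetical 400x300 landscape photo resized to 192x192.
sx, sy = squish_scale(400, 300, 192)
box = center_crop_box(400, 300)
print(f"squish: x scaled by {sx:.2f}, y scaled by {sy:.2f}")
print(f"crop: keep pixels in box {box}, then scale by {192 / 300:.2f}")
```

Squishing keeps the whole image but distorts its proportions (the two scale factors differ), while cropping keeps the proportions but throws away the pixels outside the box.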

dls = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parent_label,
    item_tfms=[Resize(192, method='squish')]
).dataloaders(path, bs=32)


A call such as dls.show_batch() displays a batch of images from our DataBlock, along with their labels. This is a quick way to check that a sample of our data (images and labels) looks correct.
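For intuition about what the get_y and splitter arguments do, here is a standard-library sketch of both ideas. The function names and file paths are hypothetical, and the shuffle will not reproduce the exact split that RandomSplitter makes with the same seed. It only illustrates the technique: label by parent directory, then hold out a seeded random 20%.

```python
import random
from pathlib import PurePosixPath

def parent_label_sketch(path):
    """Label a file by the name of the directory that contains it."""
    return PurePosixPath(path).parent.name

def random_split(items, valid_pct=0.2, seed=42):
    """Shuffle indices with a fixed seed and hold out valid_pct of them."""
    idxs = list(range(len(items)))
    random.Random(seed).shuffle(idxs)
    cut = int(valid_pct * len(items))
    valid = [items[i] for i in idxs[:cut]]
    train = [items[i] for i in idxs[cut:]]
    return train, valid

# Hypothetical file paths; 'data' is a placeholder directory name.
files = ([f"data/rotten/{i}.jpg" for i in range(5)]
         + [f"data/fresh/{i}.jpg" for i in range(5)])
train, valid = random_split(files)
print(len(train), len(valid))
print([parent_label_sketch(f) for f in valid])
```

Fixing the seed means the same files land in the validation set on every run, which keeps the reported error rate comparable between experiments.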

To train our model, we will fine-tune resnet18, one of the most widely used computer vision models, on our dataset.

clf = vision_learner(dls, resnet18, metrics=error_rate)
clf.fine_tune(3)  # assumed call and epoch count; the original fine-tuning line is missing

Using the model

It's finally time to use our model and see how well it predicts whether a fruit is rotten or not.

Image.open('rotten.jpg').to_thumb(256,256)


is_rotten,_,probs = clf.predict(PILImage.create('rotten.jpg'))
print(f"This fruit/vegetable is: {is_rotten}.")
print(f"Probability it's rotten: {probs[1]:.4f}")

This fruit/vegetable is: rotten.
Probability it's rotten: 1.0000

Image.open('fresh.jpg').to_thumb(256,256)


is_rotten,_,probs = clf.predict(PILImage.create('fresh.jpg'))
print(f"This fruit/vegetable is: {is_rotten}.")
print(f"Probability it's fresh: {probs[0]:.4f}")

This fruit/vegetable is: fresh.
Probability it's fresh: 1.0000
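The indices used above (probs[1] for rotten, probs[0] for fresh) work because fastai sorts category names alphabetically by default, so 'fresh' comes before 'rotten'. Below is a small sketch of pairing probabilities with labels by name rather than by position. The probability values are made up and label_probabilities is a hypothetical helper. With a real learner, the class order should be available as clf.dls.vocab.

```python
def label_probabilities(vocab, probs):
    """Pair each class name with its probability."""
    return dict(zip(vocab, probs))

# Class names sorted alphabetically, mirroring the order assumed above.
vocab = sorted({"rotten", "fresh"})
# Made-up probabilities for one image, in vocab order.
probs = [0.03, 0.97]

by_label = label_probabilities(vocab, probs)
prediction = max(by_label, key=by_label.get)
print(vocab)
print(f"Predicted: {prediction} ({by_label[prediction]:.2f})")
```

Looking probabilities up by name avoids silent mistakes if a third category is added later and the positions shift.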

Classifying images of rotten and fresh oranges

Let's see if our model can tell a rotten orange from a fresh one. We haven't explicitly downloaded images of fresh/rotten oranges for our training set, so this is a reasonable check of how well the model generalizes to "unseen data".

download_url(search_images('fresh orange', max_images=1)[0], 'fresh orange.jpg', show_progress=False)
Image.open('fresh orange.jpg').to_thumb(256,256)


is_rotten,_,probs = clf.predict(PILImage.create('fresh orange.jpg'))
print(f"This fruit/vegetable is: {is_rotten}.")
print(f"Probability it's fresh: {probs[0]:.4f}")

This fruit/vegetable is: fresh.
Probability it's fresh: 0.9748

download_url(search_images('rotten orange', max_images=1)[0], 'rotten orange.jpg', show_progress=False)
Image.open('rotten orange.jpg').to_thumb(256,256)


is_rotten,_,probs = clf.predict(PILImage.create('rotten orange.jpg'))
print(f"This fruit/vegetable is: {is_rotten}.")
print(f"Probability it's rotten: {probs[1]:.4f}")

This fruit/vegetable is: rotten.
Probability it's rotten: 0.9899

Not bad at all. The model seems to generalize fine, though a more accurate measure of generalizability would involve creating a separate test set and calculating performance metrics on it.
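As a rough idea of what such an evaluation could look like, here is a sketch that computes the error rate, precision, and recall from true and predicted labels. The labels are made up for a hypothetical eight-image test set. In practice they would come from running the model over held-out images.

```python
def error_rate(y_true, y_pred):
    """Fraction of predictions that do not match the true label."""
    wrong = sum(t != p for t, p in zip(y_true, y_pred))
    return wrong / len(y_true)

def precision_recall(y_true, y_pred, positive="rotten"):
    """Precision and recall for the `positive` class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

# Made-up labels for a hypothetical eight-image test set.
y_true = ["rotten", "rotten", "rotten", "rotten", "fresh", "fresh", "fresh", "fresh"]
y_pred = ["rotten", "rotten", "rotten", "fresh", "fresh", "fresh", "fresh", "rotten"]

err = error_rate(y_true, y_pred)
precision, recall = precision_recall(y_true, y_pred)
print(f"error rate: {err:.2f}, precision: {precision:.2f}, recall: {recall:.2f}")
```

For the refrigerator use case, recall on the rotten class is probably the number to watch, since a missed rotten item is worse than a false alarm.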


There you have it! With a few lines of code we have created our own image classification model by fine-tuning an off-the-shelf model with fastai. The high-level APIs that the library provides make building an initial model a breeze. If you want to run the notebook for yourself, you can check it out on Kaggle here. I urge you to try building your own classification model on images from DuckDuckGo search.

I'll be going over the rest of the fastai course in the coming weeks. Even though I have only done the first two weeks so far, I highly recommend it to anyone interested in machine learning, especially people with a coding background.


© 2022 Sajal Sharma.