Image classification using PyTorch with AlexNet

This tutorial explains how to use pre-trained models with PyTorch. We will use a pre-trained AlexNet model to predict the label of an input image.

Prerequisites
  • Run the code snippets in this article in a Google Colab notebook
  • Download the ImageNet class labels (imagenet_classes.txt) from this link and place the file in the /content directory of the Colab notebook
  • Download the sample image (cat3.png) from this link and place it in the /content directory of the Colab notebook
  • Install torch and torchvision

    
    
    pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
    
    
    

Instantiate the AlexNet model

    
    import torch
    from torchvision import models
    from torchvision import transforms
    from PIL import Image
    
    # Load AlexNet with weights pre-trained on ImageNet (downloaded on first use)
    alexnet = models.alexnet(pretrained=True)
    
    
    

Pre-process the input image for the AlexNet model

    
    
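    # Resize the shorter side to 256, crop the central 224x224 region,
    # convert to a tensor in [0, 1], then normalize with the ImageNet statistics
    preprocess_image = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ])

    # Convert to RGB so images with an alpha channel (common for PNG) have 3 channels
    image = Image.open("./cat3.png").convert("RGB")
    image_tensor = preprocess_image(image)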
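
To make the preprocessing arithmetic concrete, here is a minimal plain-Python sketch of what the steps do to an image's size and pixel values. It does not call torchvision: the helper names (resized_size, center_crop_box, normalize_pixel) and the 640x480 input size are illustrative assumptions, and the rounding follows the commonly documented behavior (shorter side scaled to 256, crop window centered), so real results may differ by a pixel between torchvision versions.

```python
# Illustrative helpers only; these are not torchvision functions.
MEAN = [0.485, 0.456, 0.406]  # per-channel ImageNet means (R, G, B)
STD = [0.229, 0.224, 0.225]   # per-channel ImageNet standard deviations


def resized_size(width, height, size=256):
    """Scale so the shorter side equals `size`, keeping the aspect ratio."""
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop=224):
    """Return (left, top, right, bottom) of a centered crop x crop window."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


def normalize_pixel(rgb):
    """Map 0-255 RGB values to [0, 1] (as ToTensor does), then standardize."""
    scaled = [value / 255.0 for value in rgb]
    return [(s - m) / d for s, m, d in zip(scaled, MEAN, STD)]


new_w, new_h = resized_size(640, 480)  # hypothetical 640x480 input image
box = center_crop_box(new_w, new_h)
print(new_w, new_h)                    # 341 256
print(box)
print(normalize_pixel([255, 255, 255]))
```

After normalization a pure white pixel lands near 2.2 to 2.6 in each channel and a pure black pixel near -2.1 to -1.8, which is the value range the pre-trained weights were trained on.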
    
    
    

Create the input tensor by adding a batch dimension to the image tensor

    
    print(image_tensor.shape)
    
    input_tensor = torch.unsqueeze(image_tensor, 0)
    print(input_tensor.shape)
    
    

Output

    
    torch.Size([3, 224, 224])
    torch.Size([1, 3, 224, 224])
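
The model expects a batch of images shaped (batch, channels, height, width), which is why a single image needs the extra leading dimension. The plain-Python sketch below mimics this with nested lists. Here shape_of is a hypothetical helper written for this illustration, and a tiny 3x2x2 "image" stands in for the real 3x224x224 tensor.

```python
def shape_of(nested):
    """Return the shape of a regular nested list, e.g. [[0, 0]] -> (1, 2)."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)


# A toy "image": 3 channels, 2 rows, 2 columns (stand-in for 3 x 224 x 224).
image = [[[0.0, 0.0], [0.0, 0.0]] for _ in range(3)]

# Wrapping the image in a one-element list plays the role of unsqueeze at dim 0.
batch = [image]

print(shape_of(image))  # (3, 2, 2)
print(shape_of(batch))  # (1, 3, 2, 2)
```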
    
    

Switch the model to evaluation mode and run inference

    
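    # eval() switches layers such as dropout to inference behavior
    alexnet.eval()

    # Gradients are not needed for inference, so skip tracking them
    with torch.no_grad():
        prediction_tensor = alexnet(input_tensor)
    print(prediction_tensor.shape)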
    
    
    

Output

    
    torch.Size([1, 1000])
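
The 1000 values in the prediction tensor are raw, unnormalized scores (one per ImageNet class), not probabilities. If probabilities are wanted, a softmax can be applied. The sketch below shows the arithmetic in plain Python on three made-up scores; with the real tensor, torch.nn.functional.softmax(prediction_tensor, dim=1) would be the usual call.

```python
import math


def softmax(scores):
    """Convert raw scores to probabilities that sum to 1."""
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


scores = [2.0, 1.0, 0.1]  # made-up scores for three hypothetical classes
probabilities = softmax(scores)
print(probabilities)
```

Softmax preserves the ordering of the scores, so the class with the highest raw score is also the class with the highest probability.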
    
    

Create a list of labels from the imagenet_classes.txt file

    
    
    # Read one class label per line; the list position is used as the class index
    with open('./imagenet_classes.txt') as f:
        labels = [line.strip() for line in f]
    
    
    

Get the predicted class index and label

    
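    # torch.max over dim 1 returns the highest score and its class index
    max_value, index_of_max_value = torch.max(prediction_tensor, 1)
    print(index_of_max_value.numpy())

    # .item() converts the one-element tensor to a plain int for list indexing
    predicted_label = labels[index_of_max_value.item()]
    print(predicted_label)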
    
    

Output

    
    tabby, tabby cat
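
Looking only at the single highest score hides how close the runner-up classes were. A minimal sketch of a top-k lookup follows, in plain Python, with a made-up score list and label list standing in for prediction_tensor and labels; top_k_labels is a hypothetical helper. On the real tensor, torch.topk(prediction_tensor, 5) returns the corresponding values and indices.

```python
import heapq


def top_k_labels(scores, labels, k=3):
    """Return the k (label, score) pairs with the highest scores, best first."""
    best = heapq.nlargest(k, range(len(scores)), key=lambda i: scores[i])
    return [(labels[i], scores[i]) for i in best]


# Made-up values for illustration; real scores come from the model.
toy_scores = [0.3, 4.1, 2.7, 3.9]
toy_labels = ["goldfish", "tabby, tabby cat", "tiger cat", "Egyptian cat"]

for label, score in top_k_labels(toy_scores, toy_labels):
    print(label, score)
```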
    
    
