# CLIP - Contrastive Language-Image Pre-Training
This notebook introduces the concept of multi-modality, that is, a single model that can handle multiple modalities of data, e.g. images and language. You have most likely given an image to something like ChatGPT and received a textual response. Indeed, ChatGPT is multi-modal.

The CLIP architecture handles the language and image modalities. It does this via a shared embedding space, that is, a vector space in which both text and images are represented. Each input is transformed into an embedding, i.e. a vector, such that text and images with similar content also have similar embeddings. So, for example, an image of a gorilla and the sentence "an image of a gorilla" would end up with very similar embeddings and be close to each other in the vector space.

Authors: Albin Åberg Dahlberg, Stina Brunzell, Paul Häusner <br>
Last update: 05.12.2024
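
The embedding-space intuition from the introduction can be sketched with a few made-up vectors. The three-dimensional embeddings below are invented for illustration (real CLIP embeddings have hundreds of dimensions), and `cosine_similarity` is a helper written for this sketch, not part of any library.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two vectors given as lists of floats."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Made-up 3-dimensional embeddings, for illustration only
gorilla_image = [0.9, 0.1, 0.2]
gorilla_text = [0.8, 0.2, 0.1]
car_text = [0.1, 0.9, 0.3]

sim_match = cosine_similarity(gorilla_image, gorilla_text)
sim_mismatch = cosine_similarity(gorilla_image, car_text)
print(f"image vs. matching text:  {sim_match:.2f}")
print(f"image vs. unrelated text: {sim_mismatch:.2f}")
```

The matching pair points in almost the same direction, so its cosine similarity is close to 1, while the unrelated pair scores much lower.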

## Imports

In [None]:
import torch
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt

import transformers

from PIL import Image

from urllib.request import urlopen
from urllib.error import URLError

import gzip
import pandas as pd

## We'll use a small sample of the WIT dataset - Wikipedia-based Image Text dataset
For those who want a closer look at how the data was curated: https://github.com/google-research-datasets/wit

The following code block downloads the sample and reads it into a dataframe.


In [None]:
!wget https://storage.googleapis.com/gresearch/wit/wit_v1.train.all-1percent_sample.tsv.gz

file_path = "wit_v1.train.all-1percent_sample.tsv.gz"

# Unzip the file and read into a dataframe
with gzip.open(file_path, 'rt') as f:
    df = pd.read_csv(f, sep='\t')

df.info()

We'll only use English samples and drop any rows with NaN cells.

In [None]:
df = df.loc[df['language'] == 'en']
# Only some columns are of interest
dataset = df[['image_url', 'hierarchical_section_title',
              'context_page_description']].reset_index(drop=True).dropna()

These are the images we'll use

In [None]:
def plot_image(image_url):
  ''' Plots the image from an url '''
  try:
      with Image.open(urlopen(image_url)) as im:
          # The following fixes some problems when loading images:
          # https://stackoverflow.com/a/64598016
          image = im.convert("RGB")
  except (URLError, OSError):
      print("please provide a valid URL or local path")
  else:
      plt.imshow(np.asarray(image))
      plt.xticks([])
      plt.yticks([])
      plt.show()
      print()

# We select specific rows
INDICES = [6, 32, 51, 70, 82]
data = dataset.iloc[INDICES].reset_index(drop=True)

# Let's look at the images!
for i in range(len(data)):
  plot_image(data.iloc[i]['image_url'])

## Load in CLIP models
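- `processor` prepares the raw input for the model: it tokenizes the text and resizes and normalizes the images into tensors
- `model` maps those tensors to embeddings and computes the image-text similarity scores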

In [None]:
model = transformers.CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = transformers.CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [None]:
def extract_image(image_url):
  ''' Loads the image from an url and returns it in a one-element list '''
  try:
      with Image.open(urlopen(image_url)) as im:
          # The following fixes some problems when loading images:
          # https://stackoverflow.com/a/64598016
          image = im.convert("RGB")
  except (URLError, OSError):
      print("please provide a valid URL or local path")
      # Re-raise so that we never try to return an undefined image
      raise

  return [image]

## Guessing game
We'll play a guessing game. Write a text that describes the image you see in `INPUT` and we'll see if it matches the image better than the other descriptions!

In [None]:
def guessing_game(image_index, text_input):
  # Use the index passed in, not the global INDEX
  image_url = data.iloc[image_index]['image_url']
  # Our image
  images = extract_image(image_url)
  # Your description first, followed by the descriptions from the dataset
  descriptions = [text_input] + list(data['hierarchical_section_title'].values)

  with torch.no_grad():
      inputs = processor(text=descriptions, images=images, return_tensors="pt", padding=True, truncation=True)
      outputs = model(**inputs)

  # Scaled image-text similarity scores: one row per image, one column per description
  dot_products_per_image = outputs.logits_per_image
  # Softmax over the descriptions to get prediction probabilities.
  probabilities = dot_products_per_image.softmax(dim=1).flatten()

  for i, desc in enumerate(descriptions):
    if len(desc) > 100:
      print(f"description: {desc[:100]}... --> p={probabilities[i]:.2f}")
    else:
      print(f"description: {desc} --> p={probabilities[i]:.2f}")
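
To see why `softmax(dim=1)` followed by `flatten()` gives one probability per description, here is a minimal sketch of the shapes involved. It uses random tensors as stand-ins for real CLIP embeddings; the embedding size of 8 and the scale factor of 100.0 are assumptions made for this illustration, since the real model learns its own scale.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Random stand-ins for real embeddings (assumption: 1 image, 6 descriptions, 8 dimensions)
image_embeds = F.normalize(torch.randn(1, 8), dim=-1)
text_embeds = F.normalize(torch.randn(6, 8), dim=-1)

# Assumed scale factor; the real model learns its own value
logit_scale = 100.0

# Scaled dot products of unit-length vectors: one row per image, one column per description
logits_per_image = logit_scale * image_embeds @ text_embeds.T

# Softmax across the descriptions, then drop the image dimension
probabilities = logits_per_image.softmax(dim=1).flatten()

print(logits_per_image.shape)
print(probabilities.shape)
print(probabilities.sum())
```

With a single image there is only one row, so flattening leaves a vector with one entry per description, and those entries sum to 1.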

### Let's start with the first image
Write your guess in `INPUT` two blocks down

In [None]:
# Don't change the index
INDEX = 0
image_url = data.iloc[INDEX]['image_url']
plot_image(image_url)

In [None]:
# Write your description here
INPUT = "A frog"
guessing_game(INDEX, INPUT)

If your text input gets a higher probability than the other descriptions, then the model considers your description more similar to the image! Do note that the output probabilities are relative to each other: change your input and all of the probabilities will change.
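
The relative nature of these probabilities can be illustrated with a small sketch. The scores below are hypothetical numbers chosen for illustration, not outputs of the model, and `softmax` is a plain-Python helper written for this example.

```python
import math

def softmax(scores):
    """Turn a list of scores into probabilities that sum to 1."""
    highest = max(scores)
    # Subtract the maximum for numerical stability; the result is unchanged
    exps = [math.exp(s - highest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical image-text scores; the first entry plays the role of your input
your_score = 24.0
weak_rivals = [20.0, 19.0, 18.0]
strong_rivals = [25.0, 19.0, 18.0]

p_weak = softmax([your_score] + weak_rivals)[0]
p_strong = softmax([your_score] + strong_rivals)[0]
print(f"against weak rivals:   p={p_weak:.2f}")
print(f"against strong rivals: p={p_strong:.2f}")
```

The score of your description is identical in both cases, yet its probability drops sharply once a better-matching description is among the candidates.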

Next image

In [None]:
# Don't change the index
INDEX = 1
image_url = data.iloc[INDEX]['image_url']
plot_image(image_url)

In [None]:
# Write your description here
INPUT = ""
guessing_game(INDEX, INPUT)

Next image

In [None]:
# Don't change the index
INDEX = 2
image_url = data.iloc[INDEX]['image_url']
plot_image(image_url)

In [None]:
# Write your description here
INPUT = ""
guessing_game(INDEX, INPUT)

Next image

In [None]:
# Don't change the index
INDEX = 3
image_url = data.iloc[INDEX]['image_url']
plot_image(image_url)

In [None]:
# Write your description here
INPUT = ""
guessing_game(INDEX, INPUT)

Last one

In [None]:
# Don't change the index
INDEX = 4
image_url = data.iloc[INDEX]['image_url']
plot_image(image_url)

In [None]:
# Write your description here
INPUT = ""
guessing_game(INDEX, INPUT)