Multimodal deep learning with DocArray
DocArray is a library for representing, sending, and storing multimodal data that can be used for a variety of different use cases.
Here we will focus on a workflow familiar to many ML engineers: Building and training a model, and then serving it to users.
This document contains two parts:
- Representing: We will use DocArray to represent multimodal data while building and training a PyTorch model. We will see how DocArray can help to organize and group your modalities and tensors and make clear what methods to expect as inputs and return as outputs.
- Sending: We will take the model that we built and trained in part one, and serve it using FastAPI. We will see how DocArray narrows the gap between model development and model deployment, and how the same data models can be reused in both contexts. That part will be very short, but that's the point!
So without further ado, let's dive into it!
1. Representing: Build and train a PyTorch model
We will train a CLIP-like model on a dataset composed of text-image pairs. The goal is to obtain a model that can understand both text and images and project them into a common embedding space.
We train the CLIP-like model on the Flickr8k dataset. To run this, you need to download and unzip the data into the same folder as your code.
Note
In this tutorial we do not aim to reproduce any CLIP results (our dataset is way too small anyway), but rather we want to show how DocArray data structures help researchers and practitioners write beautiful and pythonic multimodal PyTorch code.
#!pip install "docarray[torch,image]"
#!pip install torchvision
#!pip install transformers
#!pip install fastapi
#!pip install pandas
Create documents for handling multimodal data
The first thing we want to achieve when using DocArray is to clearly model our data so that we never get confused about which tensors represent what.
To do that we are using a concept that is at the core of DocArray: The document -- a collection of multimodal data.
The BaseDoc class allows users to define their own (nested, multimodal) document schema to represent any kind of complex data.
Let's start by defining a few documents to handle the different modalities that we will use during training.
First, we create a document for our Text modality. It will contain a number of Tokens, which we also define:
from typing import Optional

from docarray import BaseDoc
from docarray.documents import TextDoc as BaseText
from docarray.typing import TorchTensor


class Tokens(BaseDoc):
    input_ids: TorchTensor[48]
    attention_mask: TorchTensor


class Text(BaseText):
    tokens: Optional[Tokens] = None
Notice the TorchTensor type. It is a thin wrapper around torch.Tensor that can be used like any other Torch tensor, but it also enables additional features. One such feature is shape parametrization (TorchTensor[48]), which lets you hint and even enforce the desired shape of any tensor!
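For example, a field typed as TorchTensor[48] rejects data that cannot be coerced to that shape at validation time. A small illustration (our own sketch, using the Tokens document defined above):

import torch

# shape (48,) matches TorchTensor[48], so validation passes
tokens = Tokens(
    input_ids=torch.zeros(48, dtype=torch.long),
    attention_mask=torch.ones(48, dtype=torch.long),
)

# a tensor that cannot be coerced to shape (48,) raises a validation error:
# Tokens(input_ids=torch.zeros(32), attention_mask=torch.ones(32))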
To represent our image data, we use DocArray's ImageDoc:
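The rest of the code refers to an Image document. A minimal way to define it (this thin subclass is our own sketch, chosen to match how Image is used below) is as a wrapper around ImageDoc:

from docarray.documents import ImageDoc


class Image(ImageDoc):
    pass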
Under the hood, an ImageDoc looks something like this (the only notable difference being that it can take tensors from any supported ML framework):
class ImageDoc(BaseDoc):
    url: Optional[ImageUrl] = None
    tensor: Optional[TorchTensor] = None
    embedding: Optional[TorchTensor] = None
The BaseText above already includes tensor, url, and embedding fields, so we can use those on our Text document as well.
The final document used for training here is the PairTextImage, which simply combines the Text and Image modalities:
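A minimal definition consistent with how PairTextImage is used in the rest of this tutorial (a sketch on our part, with one field per modality):

class PairTextImage(BaseDoc):
    text: Text
    image: Image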
You then need to forward declare the following types. This will allow the objects to be properly pickled and unpickled.
This will be unnecessary once this issue is resolved.
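What these forward declarations could look like (a sketch; we simply reference the parametrized types once at module level so that the dynamically created classes exist before the DataLoader workers try to unpickle documents):

from docarray import DocList

# referencing the parametrized types once registers the dynamically created
# classes, so that documents can be pickled and unpickled across processes
DocList[Text]
DocList[Image]
DocList[PairTextImage]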
Create the dataset
In this section we will create a multimodal PyTorch dataset around the Flickr8k dataset using DocArray.
We will use DocArray's data loading functionality to load the data and use Torchvision and Transformers to preprocess the data before feeding it to our deep learning model:
import torchvision
from transformers import AutoTokenizer


class VisionPreprocess:
    def __init__(self):
        self.transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Resize(232),
                torchvision.transforms.RandomCrop(224),
                torchvision.transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def __call__(self, image: Image) -> None:
        image.tensor = self.transform(image.url.load())


class TextPreprocess:
    def __init__(self):
        # use the uncased checkpoint so the tokenizer matches the DistilBERT model used below
        self.tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    def __call__(self, text: Text) -> None:
        assert isinstance(text, Text)
        text.tokens = Tokens(
            **self.tokenizer(
                text.text, padding="max_length", truncation=True, max_length=48
            )
        )
VisionPreprocess and TextPreprocess implement standard preprocessing steps for images and text, nothing special here.
import pandas as pd

from docarray import DocList


def get_flickr8k_da(file: str = "captions.txt", N: Optional[int] = None):
    df = pd.read_csv(file, nrows=N)
    da = DocList[PairTextImage](
        PairTextImage(text=Text(text=i.caption), image=Image(url=f"Images/{i.image}"))
        for i in df.itertuples()
    )
    return da
In the get_flickr8k_da function we process the Flickr8k dataset into a DocList.
Now let's instantiate this dataset using the MultiModalDataset class. The constructor takes in the da and a dictionary of preprocessing transformations:
from docarray.data import MultiModalDataset
from torch.utils.data import DataLoader

da = get_flickr8k_da()
# map each field of PairTextImage to the transform that should preprocess it
preprocessing = {"image": VisionPreprocess(), "text": TextPreprocess()}

dataset = MultiModalDataset[PairTextImage](da=da, preprocessing=preprocessing)
loader = DataLoader(
    dataset,
    batch_size=128,
    collate_fn=dataset.collate_fn,
    shuffle=True,
    num_workers=4,
    multiprocessing_context="fork",
)
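Before moving on, we can sanity-check the pipeline by pulling a single batch (a small illustrative snippet of our own):

batch = next(iter(loader))
print(type(batch))  # the loader yields DocList[PairTextImage] batches, not raw tensors
print(len(batch))   # up to batch_size documents per batch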
Create the PyTorch model that works on DocArray
In this section we will create two encoders, one per modality (text and image). These encoders are normal PyTorch nn.Modules.
The only difference is that they operate on DocList rather than on torch.Tensor:
from torch import nn
from transformers import DistilBertModel


class TextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")

    def forward(self, texts: DocList[Text]) -> TorchTensor:
        last_hidden_state = self.bert(
            input_ids=texts.tokens.input_ids, attention_mask=texts.tokens.attention_mask
        ).last_hidden_state
        return self._mean_pool(last_hidden_state, texts.tokens.attention_mask)

    def _mean_pool(
        self, last_hidden_state: TorchTensor, attention_mask: TorchTensor
    ) -> TorchTensor:
        masked_output = last_hidden_state * attention_mask.unsqueeze(-1)
        return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
The TextEncoder takes a DocList of Text documents as input and returns an embedding TorchTensor as output.
A DocList[Text] can be seen as a list of Text documents, and the encoder treats it as one batch.
class VisionEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = torchvision.models.resnet18(pretrained=True)
        self.linear = nn.LazyLinear(out_features=768)

    def forward(self, images: DocList[ImageDoc]) -> TorchTensor:
        x = self.backbone(images.tensor)
        return self.linear(x)
Similarly, the VisionEncoder takes a DocList of ImageDocs as input and returns an embedding TorchTensor as output.
However, it operates on the tensor attribute of each document.
Now we can instantiate our encoders:
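A minimal instantiation (the DEVICE constant is our own convention, reused by the training and serving code below):

import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

vision_encoder = VisionEncoder().to(DEVICE)
text_encoder = TextEncoder().to(DEVICE)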
As you can see, DocArray helps us clearly convey what data is expected as input and output for each method, all through Python type hints.
Train the model in a contrastive way between Text and Image (CLIP)
Now that we have defined our dataloader and our models, we can train the two encoders in a contrastive way. The goal is to match the representations of the text and the image for each pair in the dataset.
import itertools

optim = torch.optim.Adam(
    itertools.chain(vision_encoder.parameters(), text_encoder.parameters()), lr=3e-4
)


def cosine_sim(x_mat: TorchTensor, y_mat: TorchTensor) -> TorchTensor:
    a_n, b_n = x_mat.norm(dim=1)[:, None], y_mat.norm(dim=1)[:, None]
    a_norm = x_mat / torch.clamp(a_n, min=1e-7)
    b_norm = y_mat / torch.clamp(b_n, min=1e-7)
    return torch.mm(a_norm, b_norm.transpose(0, 1)).squeeze()


def clip_loss(image: DocList[Image], text: DocList[Text]) -> TorchTensor:
    sims = cosine_sim(image.embedding, text.embedding)
    return torch.norm(sims - torch.eye(sims.shape[0], device=DEVICE))
In the type hints of cosine_sim and clip_loss you can again notice that we can treat a TorchTensor like any other torch.Tensor, and that we can make explicit what kind of data and data modalities the different functions expect.
One thing to notice here is that our dataloader does not return a torch.Tensor but a DocList[PairTextImage], which is exactly what our model can operate on.
So let's write a training loop and train our encoders:
from tqdm import tqdm

num_epoch = 5  # number of passes over the dataset; adjust to your compute budget

with torch.autocast(device_type="cuda", dtype=torch.float16):
    for epoch in range(num_epoch):
        for i, batch in tqdm(
            enumerate(loader), total=len(loader), desc=f"Epoch {epoch}"
        ):
            batch.to(DEVICE)  # DocList can be moved to device
            optim.zero_grad()

            # FORWARD PASS:
            batch.image.embedding = vision_encoder(batch.image)
            batch.text.embedding = text_encoder(batch.text)

            loss = clip_loss(batch.image, batch.text)
            if i % 30 == 0:
                print(f"{i + epoch} steps, loss: {loss}")
            loss.backward()
            optim.step()
Here we see how we can immediately group the output of each encoder with the document (and modality) it belongs to.
And with all that, we've successfully trained a CLIP-like model without ever getting confused about the meaning of any tensor!
2. Sending: Serve the model using FastAPI
Now that we have a trained CLIP model, let's see how we can serve this model with a REST API by reusing most of the code above.
Let's use FastAPI for that!
FastAPI is powerful because it allows you to define your REST API data schema in pure Python. And DocArray is fully compatible with FastAPI and Pydantic, which means that as long as you have a function that takes a document as input, FastAPI will be able to automatically translate it into a fully fledged API with documentation, an OpenAPI specification, and more:
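Before defining the endpoint we need the FastAPI app itself, the preprocessing object used below, and DocArray's FastAPI response class (a sketch; the exact import path of DocumentResponse may differ between DocArray versions):

from fastapi import FastAPI

from docarray.base_doc import DocumentResponse  # import path may vary across DocArray versions

app = FastAPI()

text_preprocess = TextPreprocess()  # reuse the preprocessing step from training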
Now all we need to do is to tell FastAPI what methods it should use to serve the model:
@app.post("/embed_text/", response_model=Text, response_class=DocumentResponse)
async def embed_text(doc: Text) -> Text:
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        with torch.inference_mode():
            text_preprocess(doc)
            da = DocList[Text]([doc], tensor_type=TorchTensor).to_doc_vec()
            da.to(DEVICE)
            doc.embedding = text_encoder(da)[0].to("cpu")
    return doc
You can see that our earlier definition of the Text document now doubles as the API schema for the /embed_text endpoint.
With this running, we can query our model over the network:
from httpx import AsyncClient

# a hypothetical example query to embed
text_input = Text(text="a man rides a horse on the beach")

async with AsyncClient(
    app=app,
    base_url="http://test",
) as ac:
    response = await ac.post("/embed_text/", data=text_input.json())
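On the client side, the JSON response can be parsed straight back into a Text document (a sketch, relying on the Pydantic-style parsing that BaseDoc supports):

doc_response = Text.parse_raw(response.content)
assert doc_response.embedding is not None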
And we're done! You have trained and served a multimodal ML model, with zero headaches and a lot of DocArray!