"""Export utilities for M-CLIP models.

Exports the multilingual text encoder to ONNX and delegates the paired
visual encoder to the OpenCLIP export path.
"""

import tempfile
import warnings
from pathlib import Path

import torch
from multilingual_clip.pt_multilingual_clip import MultilingualCLIP
from transformers import AutoTokenizer

from export.models.constants import MCLIP_TO_OPENCLIP

from .openclip import to_onnx as openclip_to_onnx
from .optimize import optimize
from .util import clean_name, get_model_path


def to_onnx(
    model_name: str,
    output_dir_visual: Path | str,
    output_dir_textual: Path | str,
) -> None:
    textual_path = get_model_path(output_dir_textual)
    with tempfile.TemporaryDirectory() as tmpdir:
        model = MultilingualCLIP.from_pretrained(model_name, cache_dir=tmpdir)
        AutoTokenizer.from_pretrained(model_name).save_pretrained(output_dir_textual)
        # inference only: disable gradient tracking before tracing
        for param in model.parameters():
            param.requires_grad_(False)
        export_text_encoder(model, textual_path)
        # the visual encoder is a plain OpenCLIP model, so reuse that export path
        openclip_to_onnx(MCLIP_TO_OPENCLIP[clean_name(model_name)], output_dir_visual)
        optimize(textual_path)


def export_text_encoder(model: MultilingualCLIP, output_path: Path | str) -> None:
    output_path = Path(output_path)

    def forward(self: MultilingualCLIP, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        embs = self.transformer(input_ids, attention_mask)[0]
        # mean-pool token embeddings, weighted by the attention mask
        embs = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None]
        embs = self.LinearTransformation(embs)
        return torch.nn.functional.normalize(embs, dim=-1)

    # unfortunately need to monkeypatch for tracing to work here
    # otherwise it hits the 2GiB protobuf serialization limit
    MultilingualCLIP.forward = forward

    # dummy token IDs and attention mask used only to trace the graph
    args = (torch.ones(1, 77, dtype=torch.int32), torch.ones(1, 77, dtype=torch.int32))
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", UserWarning)
        torch.onnx.export(
            model,
            args,
            output_path.as_posix(),
            input_names=["input_ids", "attention_mask"],
            output_names=["text_embedding"],
            opset_version=17,
            # keep batch size and sequence length dynamic in the exported graph
            dynamic_axes={
                "input_ids": {0: "batch_size", 1: "sequence_length"},
                "attention_mask": {0: "batch_size", 1: "sequence_length"},
            },
        )
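

# --- Example usage (sketch, not part of the original module) -----------------
# A minimal sketch of how to drive the export and sanity-check the result.
# The checkpoint name and output directories are illustrative assumptions, and
# the chosen M-CLIP model must have an entry in MCLIP_TO_OPENCLIP; the check
# additionally assumes numpy and onnxruntime are installed.
if __name__ == "__main__":
    model_name = "M-CLIP/XLM-Roberta-Large-Vit-B-32"  # assumed example checkpoint
    visual_dir = Path("output") / "visual"
    textual_dir = Path("output") / "textual"

    to_onnx(model_name, visual_dir, textual_dir)

    import numpy as np
    import onnxruntime as ort

    # tokenize with the tokenizer saved next to the exported textual model above
    tokenizer = AutoTokenizer.from_pretrained(str(textual_dir))
    tokens = tokenizer("a photo of a cat", return_tensors="np")
    session = ort.InferenceSession(str(get_model_path(textual_dir)))
    (embedding,) = session.run(
        ["text_embedding"],
        {
            "input_ids": tokens["input_ids"].astype(np.int32),
            "attention_mask": tokens["attention_mask"].astype(np.int32),
        },
    )
    print(embedding.shape)  # expected: (1, embedding_dim), L2-normalized rows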