mirror of
				https://github.com/immich-app/immich.git
				synced 2025-10-31 02:27:08 -04:00 
			
		
		
		
	tflite CLIP export
This commit is contained in:
		
							parent
							
								
									5f6ad9e239
								
							
						
					
					
						commit
						eb0f79b72e
					
				| @ -22,5 +22,5 @@ dependencies: | |||||||
|   - pip: |   - pip: | ||||||
|       - multilingual-clip |       - multilingual-clip | ||||||
|       - onnx-simplifier |       - onnx-simplifier | ||||||
|       - tensorflow |       - tensorflow==2.14.* | ||||||
| category: main | category: main | ||||||
|  | |||||||
| @ -13,20 +13,22 @@ class _CLIPWrapper(tf.Module): | |||||||
|         self.model = TFCLIPModel.from_pretrained(model_name) |         self.model = TFCLIPModel.from_pretrained(model_name) | ||||||
| 
 | 
 | ||||||
|     @tf.function() |     @tf.function() | ||||||
|     def encode_image(self, input): |     def encode_image(self, input_tensor): | ||||||
|         return self.model.get_image_features(input) |         return self.model.get_image_features(input_tensor) | ||||||
| 
 | 
 | ||||||
|     @tf.function() |     @tf.function() | ||||||
|     def encode_text(self, input): |     def encode_text(self, input_tensor): | ||||||
|         return self.model.get_text_features(input) |         return self.model.get_text_features(input_tensor) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| # exported model signatures use batch size 2 because of the following reasons: | # exported model signatures use batch size 2 because of the following reasons: | ||||||
| # 1. ARM-NN cannot use dynamic batch sizes | # 1. ARM-NN cannot use dynamic batch sizes for complex models like CLIP ViT | ||||||
| # 2. batch size 1 creates a larger TF-Lite model that uses a lot (50%) more RAM | # 2. batch size 1 creates a larger TF-Lite model that uses a lot (50%) more RAM | ||||||
| # 3. batch size 2 is ~50% faster on GPU than 1 while 4 (or larger) are not faster | # 3. batch size 2 is ~50% faster on GPU than 1 while 4 (or larger) are not really faster | ||||||
| # 4. batch size >2 wastes more computation if only a single image is processed | # 4. batch size >2 wastes more computation if only a single image is processed | ||||||
| BATCH_SIZE = 2 | BATCH_SIZE_IMAGE = 2 | ||||||
|  | # On most small-scale systems there will only be one query at a time, no sense in batching | ||||||
|  | BATCH_SIZE_TEXT = 1 | ||||||
| 
 | 
 | ||||||
| SIGNATURE_TEXT = "encode_text" | SIGNATURE_TEXT = "encode_text" | ||||||
| SIGNATURE_IMAGE = "encode_image" | SIGNATURE_IMAGE = "encode_image" | ||||||
| @ -52,19 +54,19 @@ def _export_temporary_tf_model(model_name, tmp_path: str, context_length: int): | |||||||
|     wrapper = _CLIPWrapper(model_name) |     wrapper = _CLIPWrapper(model_name) | ||||||
|     conf = wrapper.model.config.vision_config |     conf = wrapper.model.config.vision_config | ||||||
|     spec_visual = tf.TensorSpec( |     spec_visual = tf.TensorSpec( | ||||||
|         shape=(BATCH_SIZE, conf.num_channels, conf.image_size, conf.image_size), dtype=tf.float32 |         shape=(BATCH_SIZE_IMAGE, conf.num_channels, conf.image_size, conf.image_size), dtype=tf.float32 | ||||||
|     ) |     ) | ||||||
|     encode_image = wrapper.encode_image.get_concrete_function(spec_visual) |     encode_image = wrapper.encode_image.get_concrete_function(spec_visual) | ||||||
|     spec_text = tf.TensorSpec(shape=(BATCH_SIZE, context_length), dtype=tf.int32) |     spec_text = tf.TensorSpec(shape=(BATCH_SIZE_TEXT, context_length), dtype=tf.int32) | ||||||
|     encode_text = wrapper.encode_text.get_concrete_function(spec_text) |     encode_text = wrapper.encode_text.get_concrete_function(spec_text) | ||||||
|     signatures = {"encode_text": encode_text, "encode_image": encode_image} |     signatures = {SIGNATURE_IMAGE: encode_image, SIGNATURE_TEXT: encode_text} | ||||||
|     tf.saved_model.save(wrapper, tmp_path, signatures) |     tf.saved_model.save(wrapper, tmp_path, signatures) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _export_tflite_model(tmp_path: str, signature: str, output_path: str): | def _export_tflite_model(tmp_path: str, signature: str, output_path: str): | ||||||
|     converter = tf.lite.TFLiteConverter.from_saved_model(tmp_path, signature_keys=[signature]) |     converter = tf.lite.TFLiteConverter.from_saved_model(tmp_path, signature_keys=[signature]) | ||||||
|     converter.optimizations = [tf.lite.Optimize.DEFAULT] |     converter.optimizations = [tf.lite.Optimize.DEFAULT] | ||||||
|     converter.target_spec.supported_types = [tf.float32] |     converter.target_spec.supported_types = [tf.float16] | ||||||
|     tflite_model = converter.convert() |     tflite_model = converter.convert() | ||||||
|     with open(output_path, "wb") as f: |     with open(output_path, "wb") as f: | ||||||
|         f.write(tflite_model) |         f.write(tflite_model) | ||||||
|  | |||||||
| @ -4,9 +4,10 @@ from pathlib import Path | |||||||
| from tempfile import TemporaryDirectory | from tempfile import TemporaryDirectory | ||||||
| 
 | 
 | ||||||
| from huggingface_hub import create_repo, login, upload_folder | from huggingface_hub import create_repo, login, upload_folder | ||||||
| from models import mclip, openclip, tfclip |  | ||||||
| from rich.progress import Progress | from rich.progress import Progress | ||||||
| 
 | 
 | ||||||
|  | from models import mclip, openclip, tfclip | ||||||
|  | 
 | ||||||
| models = [ | models = [ | ||||||
|     "RN50::openai", |     "RN50::openai", | ||||||
|     "RN50::yfcc15m", |     "RN50::yfcc15m", | ||||||
|  | |||||||
							
								
								
									
										36
									
								
								machine-learning/poetry.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										36
									
								
								machine-learning/poetry.lock
									
									
									
										generated
									
									
									
								
							| @ -1,4 +1,4 @@ | |||||||
| # This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. | # This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. | ||||||
| 
 | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "aiocache" | name = "aiocache" | ||||||
| @ -3882,6 +3882,30 @@ files = [ | |||||||
| [package.dependencies] | [package.dependencies] | ||||||
| mpmath = ">=0.19" | mpmath = ">=0.19" | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "tflite-runtime" | ||||||
|  | version = "2.14.0" | ||||||
|  | description = "TensorFlow Lite is for mobile and embedded devices." | ||||||
|  | optional = false | ||||||
|  | python-versions = "*" | ||||||
|  | files = [ | ||||||
|  |     {file = "tflite_runtime-2.14.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:bb11df4283e281cd609c621ac9470ad0cb5674408593272d7593a2c6bde8a808"}, | ||||||
|  |     {file = "tflite_runtime-2.14.0-cp310-cp310-manylinux_2_34_aarch64.whl", hash = "sha256:d38c6885f5e9673c11a61ccec5cad7c032ab97340718d26b17794137f398b780"}, | ||||||
|  |     {file = "tflite_runtime-2.14.0-cp310-cp310-manylinux_2_34_armv7l.whl", hash = "sha256:7fe33f763263d1ff2733a09945a7547ab063d8bc311fd2a1be8144d850016ad3"}, | ||||||
|  |     {file = "tflite_runtime-2.14.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:195ab752e7e57329a68e54dd3dd5439fad888b9bff1be0f0dc042a3237a90e4d"}, | ||||||
|  |     {file = "tflite_runtime-2.14.0-cp311-cp311-manylinux_2_34_aarch64.whl", hash = "sha256:ce9fa5d770a9725c746dcbf6f59f3178233b3759f09982e8b2db8d2234c333b0"}, | ||||||
|  |     {file = "tflite_runtime-2.14.0-cp311-cp311-manylinux_2_34_armv7l.whl", hash = "sha256:c4e66a74165b18089c86788400af19fa551768ac782d231a9beae2f6434f7949"}, | ||||||
|  |     {file = "tflite_runtime-2.14.0-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:9f965054467f7890e678943858c6ac76a5197b17f61b48dcbaaba0af41d541a7"}, | ||||||
|  |     {file = "tflite_runtime-2.14.0-cp38-cp38-manylinux_2_34_aarch64.whl", hash = "sha256:437167fe3d8b12f50f5d694da8f45d268ab84a495e24c3dd810e02e1012125de"}, | ||||||
|  |     {file = "tflite_runtime-2.14.0-cp38-cp38-manylinux_2_34_armv7l.whl", hash = "sha256:79d8e17f68cc940df7e68a177b22dda60fcffba195fb9dd908d03724d65fd118"}, | ||||||
|  |     {file = "tflite_runtime-2.14.0-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:4aa740210a0fd9e4db4a46e9778914846b136e161525681b41575ca4896158fb"}, | ||||||
|  |     {file = "tflite_runtime-2.14.0-cp39-cp39-manylinux_2_34_aarch64.whl", hash = "sha256:be198b7dc4401204be54a15884d9e336389790eb707439524540f5a9329fdd02"}, | ||||||
|  |     {file = "tflite_runtime-2.14.0-cp39-cp39-manylinux_2_34_armv7l.whl", hash = "sha256:eca7672adca32727bbf5c0f1caf398fc17bbe222f2a684c7a2caea6fc6767203"}, | ||||||
|  | ] | ||||||
|  | 
 | ||||||
|  | [package.dependencies] | ||||||
|  | numpy = ">=1.23.2" | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "threadpoolctl" | name = "threadpoolctl" | ||||||
| version = "3.2.0" | version = "3.2.0" | ||||||
| @ -4025,6 +4049,14 @@ dev = ["tokenizers[testing]"] | |||||||
| docs = ["setuptools_rust", "sphinx", "sphinx_rtd_theme"] | docs = ["setuptools_rust", "sphinx", "sphinx_rtd_theme"] | ||||||
| testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] | testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"] | ||||||
| 
 | 
 | ||||||
|  | [[package]] | ||||||
|  | name = "torch" | ||||||
|  | version = "2.0.1" | ||||||
|  | description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" | ||||||
|  | optional = false | ||||||
|  | python-versions = "*" | ||||||
|  | files = [] | ||||||
|  | 
 | ||||||
| [[package]] | [[package]] | ||||||
| name = "torch" | name = "torch" | ||||||
| version = "2.1.0" | version = "2.1.0" | ||||||
| @ -4772,4 +4804,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"] | |||||||
| [metadata] | [metadata] | ||||||
| lock-version = "2.0" | lock-version = "2.0" | ||||||
| python-versions = "^3.11" | python-versions = "^3.11" | ||||||
| content-hash = "bba5f87aa67bc1d2283a9f4b471ef78e572337f22413870d324e908014410d53" | content-hash = "56614afdeeeec3b7f0b786771a8fcc126761c882b1033664056042833767e521" | ||||||
|  | |||||||
| @ -29,6 +29,7 @@ python-multipart = "^0.0.6" | |||||||
| orjson = "^3.9.5" | orjson = "^3.9.5" | ||||||
| safetensors = "0.3.2" | safetensors = "0.3.2" | ||||||
| gunicorn = "^21.1.0" | gunicorn = "^21.1.0" | ||||||
|  | tflite-runtime = "^2.14.0" | ||||||
| 
 | 
 | ||||||
| [tool.poetry.group.dev.dependencies] | [tool.poetry.group.dev.dependencies] | ||||||
| mypy = "^1.3.0" | mypy = "^1.3.0" | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user