diff --git a/ads/aqua/config/container_config.py b/ads/aqua/config/container_config.py index 3c45d192e..93d95993c 100644 --- a/ads/aqua/config/container_config.py +++ b/ads/aqua/config/container_config.py @@ -118,7 +118,7 @@ class AquaContainerConfig(Serializable): evaluate (Dict[str, AquaContainerConfigItem]): Evaluation container configuration items. """ - inference: Dict[str, AquaContainerConfigItem] = Field( + inference: Dict[str, List[AquaContainerConfigItem]] = Field( default_factory=dict, description="Inference container configuration items." ) finetune: Dict[str, AquaContainerConfigItem] = Field( @@ -130,7 +130,9 @@ class AquaContainerConfig(Serializable): def to_dict(self): return { - "inference": list(self.inference.values()), + "inference": [ + item for sublist in self.inference.values() for item in sublist + ], "finetune": list(self.finetune.values()), "evaluate": list(self.evaluate.values()), } @@ -149,12 +151,11 @@ def from_service_config( ------- AquaContainerConfig: The constructed container configuration. """ - - inference_items: Dict[str, AquaContainerConfigItem] = {} + inference_items: Dict[str, List[AquaContainerConfigItem]] = {} finetune_items: Dict[str, AquaContainerConfigItem] = {} evaluate_items: Dict[str, AquaContainerConfigItem] = {} for container in service_containers: - if not container.is_latest: + if "INFERENCE" not in container.usages and not container.is_latest: continue container_item = AquaContainerConfigItem( name=SERVICE_MANAGED_CONTAINER_URI_SCHEME + container.container_name, @@ -235,7 +236,9 @@ def from_service_config( ) if "INFERENCE" in usages or "MULTI_MODEL" in usages: - inference_items[container_type] = container_item + if container_type not in inference_items: + inference_items[container_type] = [] + inference_items[container_type].append(container_item) if "FINE_TUNE" in usages: finetune_items[container_type] = container_item if "EVALUATION" in usages: