# ------------------------------------------------------------------------
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
| """Engine for testing.""" | |
| import time | |
| from tokenize_anything.build_model import model_registry | |


class InferenceCommand(object):
    """Command to run batched inference."""

    def __init__(self, input_queue, output_queue, kwargs):
        self.input_queue = input_queue
        self.output_queue = output_queue
        self.kwargs = kwargs

    def build_env(self):
        """Build the environment."""
        self.batch_size = self.kwargs.get("batch_size", 1)
        self.batch_timeout = self.kwargs.get("batch_timeout", None)

    def build_model(self):
        """Build and return the model."""
        builder = model_registry[self.kwargs["model_type"]]
        return builder(device=self.kwargs["device"], checkpoint=self.kwargs["weights"])

    def build_predictor(self, model):
        """Build and return the predictor."""
        return self.kwargs["predictor_type"](model, self.kwargs)

    def send_results(self, predictor, indices, examples):
        """Send the inference results."""
        results = predictor.get_results(examples)
        if hasattr(predictor, "timers"):
            # Attach per-stage average timings when the predictor exposes timers.
            time_diffs = dict((k, v.average_time) for k, v in predictor.timers.items())
            for i, outputs in enumerate(results):
                self.output_queue.put((indices[i], time_diffs, outputs))
        else:
            for i, outputs in enumerate(results):
                self.output_queue.put((indices[i], outputs))

    def run(self):
        """Main loop to make the inference outputs."""
        self.build_env()
        model = self.build_model()
        predictor = self.build_predictor(model)
        must_stop = False
        while not must_stop:
            indices, examples = [], []
            deadline, timeout = None, None
            for i in range(self.batch_size):
                # The first example is fetched without a timeout. Once it arrives,
                # the remaining slots share a single deadline of ``batch_timeout``.
                if self.batch_timeout and i == 1:
                    deadline = time.monotonic() + self.batch_timeout
                if self.batch_timeout and i >= 1:
                    timeout = deadline - time.monotonic()
                try:
                    index, example = self.input_queue.get(timeout=timeout)
                    if index < 0:
                        # A negative index is the sentinel to shut down the loop.
                        must_stop = True
                        break
                    indices.append(index)
                    examples.append(example)
                except Exception:
                    # Timed out waiting for an example; leave this slot unfilled.
                    pass
            if len(examples) == 0:
                continue
            self.send_results(predictor, indices, examples)
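

# ------------------------------------------------------------------------
# Usage sketch: one possible way to drive ``InferenceCommand`` from a parent
# process through multiprocessing queues. ``EchoPredictor``, the "tap_vit_l"
# registry key, the checkpoint path, and the device index below are
# illustrative assumptions, not values defined by this module. The sketch
# assumes the default "fork" start method (Linux).
# ------------------------------------------------------------------------
class EchoPredictor(object):
    """Hypothetical predictor exposing the interface ``send_results`` expects."""

    def __init__(self, model, kwargs):
        self.model = model
        self.kwargs = kwargs

    def get_results(self, examples):
        # A real predictor would run the model; here we just echo the inputs.
        return [{"img_id": example["img_id"]} for example in examples]


if __name__ == "__main__":
    import multiprocessing as mp

    input_queue, output_queue = mp.Queue(), mp.Queue()
    kwargs = {
        "model_type": "tap_vit_l",          # assumed registry key
        "weights": "models/tap_vit_l.pkl",  # assumed checkpoint path
        "device": 0,
        "predictor_type": EchoPredictor,
        "batch_size": 4,
        "batch_timeout": 0.01,
    }
    worker = mp.Process(target=InferenceCommand(input_queue, output_queue, kwargs).run)
    worker.start()

    for img_id in range(8):
        input_queue.put((img_id, {"img_id": img_id}))
    input_queue.put((-1, None))  # negative index: sentinel to stop the worker

    for _ in range(8):
        print(output_queue.get())
    worker.join()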