def load_stt_model(sst_model_id: str, sst_adapter_id: str | None, sst_pipeline_kwargs: dict) -> Pipeline:
    """Build an automatic-speech-recognition pipeline for the given model.

    The model is loaded with ``AutoModelForSpeechSeq2Seq.from_pretrained``,
    optionally gets an adapter loaded and activated under the name
    ``"adapter"``, and is moved to ``device``. The tokenizer and feature
    extractor come from the ``AutoProcessor`` of the same model id.

    NOTE(review): ``torch_dtype`` and ``device`` are not parameters; they are
    read from the enclosing (presumably module-level) scope -- confirm they
    are defined before this function is called.

    Args:
        sst_model_id: Model identifier passed to ``from_pretrained`` for both
            the model and the processor.
        sst_adapter_id: Adapter identifier to load on top of the model. A
            falsy value (``None`` or ``""``) skips adapter loading.
        sst_pipeline_kwargs: Extra keyword arguments forwarded to
            ``pipeline(...)``. May contain ``batch_size`` to override the
            default of 1. ``None`` is treated as an empty dict.

    Returns:
        The object returned by ``pipeline("automatic-speech-recognition", ...)``.
    """
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        sst_model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )

    if sst_adapter_id:
        print(f"Loading adapter {sst_adapter_id} for model {sst_model_id}")
        model.load_adapter(sst_adapter_id, adapter_name="adapter")
        model.set_adapter("adapter")

    model.to(device)

    processor = AutoProcessor.from_pretrained(sst_model_id)

    # Merge the default batch size with the caller's kwargs instead of passing
    # `batch_size=1` next to `**sst_pipeline_kwargs`: the latter raises
    # "got multiple values for keyword argument" as soon as the caller
    # supplies its own batch_size.
    pipeline_kwargs = {"batch_size": 1, **(sst_pipeline_kwargs or {})}

    # https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.AutomaticSpeechRecognitionPipeline
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
        **pipeline_kwargs,
    )
|