Nymbo commited on
Commit
3032e63
·
verified ·
1 Parent(s): e922c8c

Update TTS/tts/layers/xtts/gpt_inference.py

Browse files
TTS/tts/layers/xtts/gpt_inference.py CHANGED
@@ -2,11 +2,17 @@ import math
2
 
3
  import torch
4
  from torch import nn
5
- from transformers import GPT2PreTrainedModel
 
 
 
 
 
 
6
  from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
7
 
8
 
9
- class GPT2InferenceModel(GPT2PreTrainedModel):
10
  """Override GPT2LMHeadModel to allow for prefix conditioning."""
11
 
12
  def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
 
2
 
3
  import torch
4
  from torch import nn
5
+
6
+ try:
7
+ from transformers import GPT2PreTrainedModel, GenerationMixin
8
+ except ImportError: # pragma: no cover - new transformers layout
9
+ from transformers import GPT2PreTrainedModel # type: ignore
10
+ from transformers.generation.utils import GenerationMixin
11
+
12
  from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
13
 
14
 
15
+ class GPT2InferenceModel(GPT2PreTrainedModel, GenerationMixin):
16
  """Override GPT2LMHeadModel to allow for prefix conditioning."""
17
 
18
  def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):