Add Mixtral-8x22B-v0.1 model support #286

Open · wants to merge 1 commit into base: main
1 change: 1 addition & 0 deletions README.md
@@ -250,6 +250,7 @@ Below is a list of all the supported models via `BaseModel` class of `xTuring` a
|GPT-2 | gpt2|
|LlaMA | llama|
|LlaMA2 | llama2|
|Mixtral-8x22B | mixtral|
|OPT-1.3B | opt|

The above mentioned are the base variants of the LLMs. Below are the templates to get their `LoRA`, `INT8`, `INT8 + LoRA` and `INT4 + LoRA` versions.
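With this entry, the Mixtral key can be used like any other `BaseModel` key. A minimal sketch, assuming the variant keys registered elsewhere in this PR (note that no INT4/kbit variant is included here):

```python
from xturing.models import BaseModel

# Base Mixtral-8x22B variant
model = BaseModel.create("mixtral")

# Memory-efficient variants registered in this PR
lora_model = BaseModel.create("mixtral_lora")            # LoRA
int8_model = BaseModel.create("mixtral_int8")            # INT8
lora_int8_model = BaseModel.create("mixtral_lora_int8")  # INT8 + LoRA
```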
9 changes: 5 additions & 4 deletions docs/docs/overview/quickstart/test.jsx
@@ -15,11 +15,12 @@ const modelList = {
cerebras: 'Cerebras',
distilgpt2: 'DistilGPT-2',
galactica: 'Galactica',
gptj: 'GPT-J',
gpt2: 'GPT-2',
llama: 'LLaMA',
llama2: 'LLaMA 2',
opt: 'OPT',
mixtral: 'Mixtral',
}

export default function Test(
@@ -37,7 +38,7 @@ export default function Test(
} else {
finalKey = `${code.model}_${code.technique}`
}

useEffect(() => {
setCode({
model: 'llama',
@@ -92,8 +93,8 @@ from xturing.models import BaseModel
dataset = ${instruction}Dataset('...')

# Load the model
model = BaseModel.create('${finalKey}')`}
/>
</div>
)
}
3 changes: 2 additions & 1 deletion docs/docs/overview/supported_models.md
@@ -6,7 +6,7 @@ description: Models Supported by xTuring

<!-- # Models supported by xTuring -->
## Base versions
| Model | Model Key | LoRA | INT8 | LoRA + INT8 | LoRA + INT4 |
| ------ | --- | :---: | :---: | :---: | :---: |
| BLOOM 1.1B| bloom | ✅ | ✅ | ✅ | ✅ |
| Cerebras 1.3B| cerebras | ✅ | ✅ | ✅ | ✅ |
@@ -18,6 +18,7 @@ description: Models Supported by xTuring
| LLaMA 7B | llama | ✅ | ✅ | ✅ | ✅ |
| LLaMA2 | llama2 | ✅ | ✅ | ✅ | ✅ |
| OPT 1.3B | opt | ✅ | ✅ | ✅ | ✅ |
| Mixtral-8x22B | mixtral | ✅ | ✅ | ✅ | |

### Memory-efficient versions
> The above mentioned are the base variants of the LLMs. Below are the templates to get their `LoRA`, `INT8`, `INT8 + LoRA` and `INT4 + LoRA` versions.
14 changes: 14 additions & 0 deletions examples/models/mixtral/mixtral.py
@@ -0,0 +1,14 @@
from xturing.datasets.instruction_dataset import InstructionDataset
from xturing.models import BaseModel

instruction_dataset = InstructionDataset("./alpaca_data")

# Initialize the model
model = BaseModel.create("mixtral")

# Fine-tune the model
model.finetune(dataset=instruction_dataset)

# Once the model has been fine-tuned, you can start doing inferences
output = model.generate(texts=["Why are LLM models becoming so important?"])
print("Generated output by the model: {}".format(output))
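To keep the example consistent with the other model scripts, the fine-tuned weights can also be persisted. This assumes the `model.save(...)` helper used elsewhere in xTuring; the path is illustrative:

```python
# Persist the fine-tuned weights for later reuse (path is illustrative)
model.save("./mixtral_weights")
```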
26 changes: 26 additions & 0 deletions src/xturing/config/finetuning_config.yaml
@@ -302,6 +302,32 @@ mamba:
learning_rate: 5e-5
weight_decay: 0.01

mixtral:
learning_rate: 5e-5
weight_decay: 0.01
num_train_epochs: 3
batch_size: 1

mixtral_lora:
learning_rate: 1e-4
weight_decay: 0.01
num_train_epochs: 3
batch_size: 4

mixtral_int8:
learning_rate: 1e-4
weight_decay: 0.01
num_train_epochs: 3
batch_size: 8
max_length: 256

mixtral_lora_int8:
learning_rate: 1e-4
weight_decay: 0.01
num_train_epochs: 3
batch_size: 8
max_length: 256

opt:
learning_rate: 5e-5
weight_decay: 0.01
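These entries only set the defaults. Assuming the `finetuning_config()` accessor that `BaseModel` exposes for the other keys, they can be inspected and overridden before calling `finetune`, roughly as follows (values are illustrative):

```python
from xturing.datasets.instruction_dataset import InstructionDataset
from xturing.models import BaseModel

model = BaseModel.create("mixtral_lora")

# Override the YAML defaults at runtime (values are illustrative)
finetuning_config = model.finetuning_config()
finetuning_config.batch_size = 2
finetuning_config.num_train_epochs = 1
finetuning_config.learning_rate = 5e-5

instruction_dataset = InstructionDataset("./alpaca_data")
model.finetune(dataset=instruction_dataset)
```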
24 changes: 24 additions & 0 deletions src/xturing/config/generation_config.yaml
@@ -256,6 +256,30 @@ llama2_lora_kbit:
mamba:
do_sample: false

# Contrastive search
mixtral:
penalty_alpha: 0.6
top_k: 4
max_new_tokens: 256
do_sample: false

# Contrastive search
mixtral_lora:
penalty_alpha: 0.6
top_k: 4
max_new_tokens: 256
do_sample: false

# Greedy search
mixtral_int8:
max_new_tokens: 256
do_sample: false

# Greedy search
mixtral_lora_int8:
max_new_tokens: 256
do_sample: false

# Contrastive search
opt:
penalty_alpha: 0.6
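Similarly, the contrastive-search defaults above can be tuned per run, assuming the `generation_config()` accessor available for the other models; the values below are illustrative:

```python
from xturing.models import BaseModel

model = BaseModel.create("mixtral")

# Adjust the contrastive-search defaults from the YAML (illustrative values)
generation_config = model.generation_config()
generation_config.top_k = 8
generation_config.max_new_tokens = 128

output = model.generate(texts=["What does a mixture-of-experts layer do?"])
print(output)
```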
10 changes: 10 additions & 0 deletions src/xturing/engines/__init__.py
@@ -59,6 +59,12 @@
LlamaLoraKbitEngine,
)
from xturing.engines.mamba_engine import MambaEngine
from xturing.engines.mixtral_engine import (
MixtralEngine,
MixtralInt8Engine,
MixtralLoraEngine,
MixtralLoraInt8Engine,
)
from xturing.engines.opt_engine import (
OPTEngine,
OPTInt8Engine,
@@ -109,6 +115,10 @@
BaseEngine.add_to_registry(LLama2LoraInt8Engine.config_name, LLama2LoraInt8Engine)
BaseEngine.add_to_registry(LLama2LoraKbitEngine.config_name, LLama2LoraKbitEngine)
BaseEngine.add_to_registry(MambaEngine.config_name, MambaEngine)
BaseEngine.add_to_registry(MixtralEngine.config_name, MixtralEngine)
BaseEngine.add_to_registry(MixtralInt8Engine.config_name, MixtralInt8Engine)
BaseEngine.add_to_registry(MixtralLoraEngine.config_name, MixtralLoraEngine)
BaseEngine.add_to_registry(MixtralLoraInt8Engine.config_name, MixtralLoraInt8Engine)
BaseEngine.add_to_registry(OPTEngine.config_name, OPTEngine)
BaseEngine.add_to_registry(OPTInt8Engine.config_name, OPTInt8Engine)
BaseEngine.add_to_registry(OPTLoraEngine.config_name, OPTLoraEngine)
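Each engine is registered under its `config_name`, so the string key resolves to the concrete class. A sketch of how that lookup would be used, assuming `BaseEngine` exposes the same registry-backed `create()` factory as `BaseModel`:

```python
from xturing.engines import BaseEngine

# The registry maps "mixtral_lora_engine" -> MixtralLoraEngine; the model
# wrappers below pass these config_name strings to their causal base classes.
engine = BaseEngine.create("mixtral_lora_engine")
```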
63 changes: 63 additions & 0 deletions src/xturing/engines/mixtral_engine.py
@@ -0,0 +1,63 @@
from pathlib import Path
from typing import Optional, Union

from xturing.engines.causal import CausalEngine, CausalLoraEngine


class MixtralEngine(CausalEngine):
config_name: str = "mixtral_engine"

def __init__(self, weights_path: Optional[Union[str, Path]] = None):
super().__init__(
model_name="mistral-community/Mixtral-8x22B-v0.1",
weights_path=weights_path,
trust_remote_code=True,
)

self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id


class MixtralLoraEngine(CausalLoraEngine):
config_name: str = "mixtral_lora_engine"

def __init__(self, weights_path: Optional[Union[str, Path]] = None):
super().__init__(
model_name="mistral-community/Mixtral-8x22B-v0.1",
weights_path=weights_path,
target_modules=["q_proj", "v_proj"],
trust_remote_code=True,
)

self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id


class MixtralInt8Engine(CausalEngine):
config_name: str = "mixtral_int8_engine"

def __init__(self, weights_path: Optional[Union[str, Path]] = None):
super().__init__(
model_name="mistral-community/Mixtral-8x22B-v0.1",
weights_path=weights_path,
load_8bit=True,
trust_remote_code=True,
)
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id


class MixtralLoraInt8Engine(CausalLoraEngine):
config_name: str = "mixtral_lora_int8_engine"

def __init__(self, weights_path: Optional[Union[str, Path]] = None):
super().__init__(
model_name="mistral-community/Mixtral-8x22B-v0.1",
weights_path=weights_path,
load_8bit=True,
target_modules=["q_proj", "v_proj"],
trust_remote_code=True,
)

self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
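The LoRA engines restrict adaptation to the attention projections `q_proj` and `v_proj`. For context, this is roughly the PEFT adapter configuration those arguments imply, assuming `CausalLoraEngine` builds a `LoraConfig` internally; the rank, alpha, and dropout values below are illustrative, not the engine's defaults:

```python
from peft import LoraConfig

# Illustrative LoRA adapter config matching target_modules above
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
```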
5 changes: 5 additions & 0 deletions src/xturing/models/__init__.py
@@ -44,6 +44,7 @@
Llama2LoraKbit,
)
from xturing.models.mamba import Mamba
from xturing.models.mixtral import Mixtral, MixtralInt8, MixtralLora, MixtralLoraInt8
from xturing.models.opt import OPT, OPTInt8, OPTLora, OPTLoraInt8
from xturing.models.stable_diffusion import StableDiffusion

@@ -90,6 +91,10 @@
BaseModel.add_to_registry(Llama2LoraInt8.config_name, Llama2LoraInt8)
BaseModel.add_to_registry(Llama2LoraKbit.config_name, Llama2LoraKbit)
BaseModel.add_to_registry(Mamba.config_name, Mamba)
BaseModel.add_to_registry(Mixtral.config_name, Mixtral)
BaseModel.add_to_registry(MixtralInt8.config_name, MixtralInt8)
BaseModel.add_to_registry(MixtralLora.config_name, MixtralLora)
BaseModel.add_to_registry(MixtralLoraInt8.config_name, MixtralLoraInt8)
BaseModel.add_to_registry(OPT.config_name, OPT)
BaseModel.add_to_registry(OPTInt8.config_name, OPTInt8)
BaseModel.add_to_registry(OPTLora.config_name, OPTLora)
42 changes: 42 additions & 0 deletions src/xturing/models/mixtral.py
@@ -0,0 +1,42 @@
from typing import Optional

from xturing.engines.mixtral_engine import (
MixtralEngine,
MixtralInt8Engine,
MixtralLoraEngine,
MixtralLoraInt8Engine,
)
from xturing.models.causal import (
CausalInt8Model,
CausalLoraInt8Model,
CausalLoraModel,
CausalModel,
)


class Mixtral(CausalModel):
config_name: str = "mixtral"

def __init__(self, weights_path: Optional[str] = None):
super().__init__(MixtralEngine.config_name, weights_path)


class MixtralLora(CausalLoraModel):
config_name: str = "mixtral_lora"

def __init__(self, weights_path: Optional[str] = None):
super().__init__(MixtralLoraEngine.config_name, weights_path)


class MixtralInt8(CausalInt8Model):
config_name: str = "mixtral_int8"

def __init__(self, weights_path: Optional[str] = None):
super().__init__(MixtralInt8Engine.config_name, weights_path)


class MixtralLoraInt8(CausalLoraInt8Model):
config_name: str = "mixtral_lora_int8"

def __init__(self, weights_path: Optional[str] = None):
super().__init__(MixtralLoraInt8Engine.config_name, weights_path)
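Since each wrapper accepts an optional `weights_path`, a previously saved fine-tune can be reloaded through the class directly; the directory below is illustrative:

```python
from xturing.models.mixtral import MixtralLoraInt8

# Reload fine-tuned weights saved earlier (path is illustrative)
model = MixtralLoraInt8(weights_path="./mixtral_weights")
output = model.generate(texts=["Summarize what a mixture-of-experts layer does."])
print(output)
```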