Compare commits

...

4 Commits

Author SHA1 Message Date
Eve
5c3d0f1824 ggml : IQ4_NL sgemm + Q4_0 AVX optimization (#9422)
* squashed

readd my iq4_nl sgemm PR https://github.com/ggerganov/llama.cpp/pull/8049

have ggml_vec_dot_q4_0 do two blocks per loop for avx

try out f16c ggml_vec_dot_iq4_nl, but it's not really faster. as per https://github.com/ggerganov/llama.cpp/pull/8549 we can calculate several blocks at a time with no issue

* shuffle

* remove f16c iq4_nl as i cant make it faster than before
2024-09-16 09:48:24 +03:00
Shane A
0aadac10c7 llama : support OLMoE (#9462) 2024-09-16 09:47:37 +03:00
CarryFun
95ca85168b llama : support MiniCPM3 (#9322)
Co-authored-by: 范睿凯 <fanruikai@modelbest.cn>
2024-09-16 09:45:20 +03:00
Vinesh Janarthanan
441b72b91f main : option to disable context shift (#9484)
* added cli arg to disable context shift

* reverted precommit

* updated README.md for main

* white space

* allow disabling context shift in the server

* Update common/arg.cpp

no-context-shift only works for main example

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* added server example to --no-context-shift args

* removed server changes

* white space

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2024-09-16 09:20:01 +03:00
11 changed files with 776 additions and 72 deletions

View File

@@ -77,6 +77,7 @@ Typically finetunes of the base models below are supported as well.
- [x] [SEA-LION](https://huggingface.co/models?search=sea-lion)
- [x] [GritLM-7B](https://huggingface.co/GritLM/GritLM-7B) + [GritLM-8x7B](https://huggingface.co/GritLM/GritLM-8x7B)
- [x] [OLMo](https://allenai.org/olmo)
- [x] [OLMoE](https://huggingface.co/allenai/OLMoE-1B-7B-0924)
- [x] [Granite models](https://huggingface.co/collections/ibm-granite/granite-code-models-6624c5cec322e4c148c8b330)
- [x] [GPT-NeoX](https://github.com/EleutherAI/gpt-neox) + [Pythia](https://github.com/EleutherAI/pythia)
- [x] [Snowflake-Arctic MoE](https://huggingface.co/collections/Snowflake/arctic-66290090abe542894a5ac520)

View File

@@ -685,6 +685,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
params.n_keep = value;
}
));
add_opt(llama_arg(
{"--no-context-shift"},
format("disables context shift on inifinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
[](gpt_params & params) {
params.ctx_shift = false;
}
).set_examples({LLAMA_EXAMPLE_MAIN}));
add_opt(llama_arg(
{"--chunks"}, "N",
format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
@@ -1985,4 +1992,3 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
return ctx_arg;
}

View File

@@ -246,6 +246,7 @@ struct gpt_params {
bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool flash_attn = false; // flash attention
bool no_perf = false; // disable performance metrics
bool ctx_shift = true; // context shift on inifinite text generation
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool logits_all = false; // return logits for all tokens in the batch

View File

@@ -1841,6 +1841,60 @@ class MiniCPMModel(Model):
return [(self.map_tensor_name(name), data_torch)]
@Model.register("MiniCPM3ForCausalLM")
class MiniCPM3Model(Model):
model_arch = gguf.MODEL_ARCH.MINICPM3
def set_gguf_parameters(self):
hparams = self.hparams
rope_dims = hparams["qk_rope_head_dim"]
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"])
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
rope_scaling = self.find_hparam(['rope_scaling'], True)
if rope_scaling is None:
return
long_factors = rope_scaling.get('long_factor', None)
short_factors = rope_scaling.get('short_factor', None)
if long_factors is None or short_factors is None:
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32))
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))
def set_vocab(self):
self._set_vocab_llama_hf()
def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
if n_kv_head is not None and n_head != n_kv_head:
n_head //= n_kv_head
return (
weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
.reshape(weights.shape)
)
@Model.register("QWenLMHeadModel")
class QwenModel(Model):
model_arch = gguf.MODEL_ARCH.QWEN
@@ -2944,6 +2998,66 @@ class OlmoModel(Model):
return [(self.map_tensor_name(name), data_torch)]
@Model.register("OlmoeForCausalLM")
class OlmoeModel(Model):
model_arch = gguf.MODEL_ARCH.OLMOE
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_layer_norm_rms_eps(1e-5)
if (n_experts := self.hparams.get("num_experts")) is not None:
self.gguf_writer.add_expert_count(n_experts)
_experts: list[dict[str, Tensor]] | None = None
# Copied from: Qwen2MoeModel
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# process the experts separately
if name.find("experts") != -1:
n_experts = self.hparams["num_experts"]
assert bid is not None
if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]
self._experts[bid][name] = data_torch
if len(self._experts[bid]) >= n_experts * 3:
tensors: list[tuple[str, Tensor]] = []
# merge the experts into a single 3d tensor
for w_name in ["down_proj", "gate_proj", "up_proj"]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]
data_torch = torch.stack(datas, dim=0)
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
new_name = self.map_tensor_name(merged_name)
tensors.append((new_name, data_torch))
return tensors
else:
return []
return [(self.map_tensor_name(name), data_torch)]
# Copied from: Qwen2MoeModel
def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
# flatten `list[dict[str, Tensor]]` into `list[str]`
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")
@Model.register("JinaBertModel", "JinaBertForMaskedLM")
class JinaBertV2Model(BertModel):
model_arch = gguf.MODEL_ARCH.JINA_BERT_V2

View File

@@ -161,6 +161,8 @@ A value of -1 will enable infinite text generation, even though we have a finite
If the pause is undesirable, a value of -2 will stop generation immediately when the context is filled.
The `--no-context-shift` option allows you to stop the infinite text generation once the finite context window is full.
It is important to note that the generated text may be shorter than the specified number of tokens if an End-of-Sequence (EOS) token or a reverse prompt is encountered. In interactive mode, text generation will pause and control will be returned to the user. In non-interactive mode, the program will end. In both cases, the text generation may stop before reaching the specified `--predict` value. If you want the model to keep going without ever producing End-of-Sequence on its own, you can use the `--ignore-eos` parameter.
### Temperature

View File

@@ -559,29 +559,35 @@ int main(int argc, char ** argv) {
// if we run out of context:
// - take the n_keep first tokens from the original prompt (via n_past)
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
if (n_past + (int) embd.size() >= n_ctx) {
if (params.n_predict == -2) {
LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
if (!params.ctx_shift){
LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
break;
} else {
if (params.n_predict == -2) {
LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
break;
}
const int n_left = n_past - params.n_keep;
const int n_discard = n_left/2;
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
n_past -= n_discard;
LOG_DBG("after swap: n_past = %d\n", n_past);
LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
LOG_DBG("clear session path\n");
path_session.clear();
}
const int n_left = n_past - params.n_keep;
const int n_discard = n_left/2;
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
n_past -= n_discard;
LOG_DBG("after swap: n_past = %d\n", n_past);
LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
LOG_DBG("clear session path\n");
path_session.clear();
}
} else {
// context extension via Self-Extend

View File

@@ -230,6 +230,12 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
return _mm_packus_epi16( bytes1, bytes2);
}
static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
const __m128i ax = _mm_sign_epi8(x, x);
const __m128i sy = _mm_sign_epi8(y, x);
return _mm_maddubs_epi16(ax, sy);
}
#endif
#elif defined(__SSSE3__)
// horizontally add 4x4 floats
@@ -4206,37 +4212,37 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
sumf = hsum_float_8(acc);
#elif defined(__AVX__)
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
const __m128i mone = _mm_set1_epi16(1);
// Main loop
for (; ib < nb; ++ib) {
// Compute combined scale for the block
const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
__m256 accum1 = _mm256_setzero_ps();
__m256 accum2 = _mm256_setzero_ps();
for (; ib + 1 < nb; ib += 2) {
const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
const __m128i lowMask = _mm_set1_epi8(0xF);
const __m128i off = _mm_set1_epi8(8);
const __m128i tmp = _mm_loadu_si128((const __m128i *)x[ib].qs);
__m128i bx_0 = _mm_and_si128(lowMask, tmp);
__m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs);
bx_0 = _mm_sub_epi8(bx_0, off);
const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
by_0 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16));
bx_0 = _mm_sub_epi8(bx_0, off);
const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0);
// Convert int32_t to float
__m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
// Apply the scale, and accumulate
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
const __m128i q4b_1_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_1), _mm_set1_epi8(8));
const __m128i q4b_1_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_1, 4)), _mm_set1_epi8(8));
const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8));
const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8));
const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0);
const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1);
const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0);
const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1);
const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone);
const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone);
const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone);
const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone);
accum1 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)),
_mm256_cvtepi32_ps(MM256_SET_M128I(p_1_1, p_1_0))), accum1);
accum2 = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)),
_mm256_cvtepi32_ps(MM256_SET_M128I(p_2_1, p_2_0))), accum2);
}
sumf = hsum_float_8(acc);
sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
#elif defined(__SSSE3__)
// set constants
const __m128i lowMask = _mm_set1_epi8(0xF);
@@ -11819,15 +11825,6 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * restrict s, size_t bs, const void *
#endif
}
#if defined(__AVX__)
static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) {
const __m128i ax = _mm_sign_epi8(x, x);
const __m128i sy = _mm_sign_epi8(y, x);
return _mm_maddubs_epi16(ax, sy);
}
#endif
#if defined(__AVX2__)
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
const __m256i ax = _mm256_sign_epi8(x, x);

View File

@@ -235,6 +235,14 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
}
#endif // __AVX512F__
////////////////////////////////////////////////////////////////////////////////////////////////////
// CONSTANTS
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// FLOATING POINT MATRIX MULTIPLICATION
@@ -933,6 +941,20 @@ class tinyBLAS_Q0_AVX {
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
}
inline __m256i load(const block_iq4_nl *b) {
return MM256_SET_M128I(load1(b), load0(b));
}
inline __m128i load0(const block_iq4_nl *b) {
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), x));
}
inline __m128i load1(const block_iq4_nl *b) {
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)));
}
inline __m256 updot(__m256i u, __m256i s) {
__m256i res;
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
@@ -1159,6 +1181,22 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
#endif
}
case GGML_TYPE_IQ4_NL: {
if (Btype != GGML_TYPE_Q8_0)
return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float> tb{
k, (const block_iq4_nl *)A, lda,
(const block_q8_0 *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n);
return true;
#else
return false;
#endif
}
default:
return false;
}

View File

@@ -210,6 +210,7 @@ class MODEL_ARCH(IntEnum):
ORION = auto()
INTERNLM2 = auto()
MINICPM = auto()
MINICPM3 = auto()
GEMMA = auto()
GEMMA2 = auto()
STARCODER2 = auto()
@@ -219,6 +220,7 @@ class MODEL_ARCH(IntEnum):
COMMAND_R = auto()
DBRX = auto()
OLMO = auto()
OLMOE = auto()
OPENELM = auto()
ARCTIC = auto()
DEEPSEEK2 = auto()
@@ -364,6 +366,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.ORION: "orion",
MODEL_ARCH.INTERNLM2: "internlm2",
MODEL_ARCH.MINICPM: "minicpm",
MODEL_ARCH.MINICPM3: "minicpm3",
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.STARCODER2: "starcoder2",
@@ -373,6 +376,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.COMMAND_R: "command-r",
MODEL_ARCH.DBRX: "dbrx",
MODEL_ARCH.OLMO: "olmo",
MODEL_ARCH.OLMOE: "olmoe",
MODEL_ARCH.OPENELM: "openelm",
MODEL_ARCH.ARCTIC: "arctic",
MODEL_ARCH.DEEPSEEK2: "deepseek2",
@@ -867,6 +871,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.MINICPM3: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q_A,
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_B,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV_A_NORM,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.GEMMA: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@@ -1008,6 +1029,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.OLMOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
],
MODEL_ARCH.OPENELM: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,

View File

@@ -13,7 +13,7 @@ class TensorNameMap:
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
"model.embed_tokens", # llama-hf nemotron
"model.embed_tokens", # llama-hf nemotron olmoe
"tok_embeddings", # llama-pth
"embeddings.word_embeddings", # bert nomic-bert
"language_model.embedding.word_embeddings", # persimmon
@@ -54,7 +54,7 @@ class TensorNameMap:
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
@@ -66,7 +66,7 @@ class TensorNameMap:
MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox
"transformer.ln_f", # gpt2 gpt-j falcon jais exaone
"model.norm", # llama-hf baichuan internlm2
"model.norm", # llama-hf baichuan internlm2 olmoe
"norm", # llama-pth
"transformer.norm_f", # mpt dbrx
"ln_f", # refact bloom qwen gpt2
@@ -98,7 +98,7 @@ class TensorNameMap:
"transformer.h.{bid}.input_layernorm", # falcon7b
"h.{bid}.input_layernorm", # bloom
"transformer.h.{bid}.ln_mlp", # falcon40b
"model.layers.{bid}.input_layernorm", # llama-hf nemotron
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe
"layers.{bid}.attention_norm", # llama-pth
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
"model.layers.{bid}.ln1", # yi
@@ -142,7 +142,7 @@ class TensorNameMap:
# Attention query
MODEL_TENSOR.ATTN_Q: (
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe
"layers.{bid}.attention.wq", # llama-pth
"encoder.layer.{bid}.attention.self.query", # bert
"transformer.h.{bid}.attn.q_proj", # gpt-j
@@ -154,7 +154,7 @@ class TensorNameMap:
# Attention key
MODEL_TENSOR.ATTN_K: (
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe
"layers.{bid}.attention.wk", # llama-pth
"encoder.layer.{bid}.attention.self.key", # bert
"transformer.h.{bid}.attn.k_proj", # gpt-j
@@ -167,7 +167,7 @@ class TensorNameMap:
# Attention value
MODEL_TENSOR.ATTN_V: (
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe
"layers.{bid}.attention.wv", # llama-pth
"encoder.layer.{bid}.attention.self.value", # bert
"transformer.h.{bid}.attn.v_proj", # gpt-j
@@ -185,7 +185,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe
"layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
"transformer.h.{bid}.attn.out_proj", # gpt-j
@@ -229,7 +229,7 @@ class TensorNameMap:
"transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone
"h.{bid}.post_attention_layernorm", # bloom
"transformer.blocks.{bid}.norm_2", # mpt
"model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron
"model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe
"layers.{bid}.ffn_norm", # llama-pth
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
"model.layers.{bid}.ln2", # yi
@@ -253,7 +253,7 @@ class TensorNameMap:
MODEL_TENSOR.FFN_GATE_INP: (
"layers.{bid}.feed_forward.gate", # mixtral
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
"model.layers.{bid}.mlp.gate", # qwen2moe
"model.layers.{bid}.mlp.gate", # qwen2moe olmoe
"transformer.decoder_layer.{bid}.router", # Grok
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
),
@@ -295,7 +295,7 @@ class TensorNameMap:
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe (merged)
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
),
MODEL_TENSOR.FFN_UP_SHEXP: (
@@ -327,7 +327,7 @@ class TensorNameMap:
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe (merged)
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
),
MODEL_TENSOR.FFN_GATE_SHEXP: (
@@ -367,7 +367,7 @@ class TensorNameMap:
"layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe (merged)
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
),
MODEL_TENSOR.FFN_DOWN_SHEXP: (
@@ -378,7 +378,7 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_Q_NORM: (
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
"model.layers.{bid}.self_attn.q_norm", # cohere
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
"transformer.layers.{bid}.attn.q_norm", # openelm
@@ -387,7 +387,7 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_K_NORM: (
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
"model.layers.{bid}.self_attn.k_norm", # cohere
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
"transformer.layers.{bid}.attn.k_norm", # openelm

View File

@@ -193,6 +193,7 @@ enum llm_arch {
LLM_ARCH_ORION,
LLM_ARCH_INTERNLM2,
LLM_ARCH_MINICPM,
LLM_ARCH_MINICPM3,
LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2,
LLM_ARCH_STARCODER2,
@@ -201,6 +202,7 @@ enum llm_arch {
LLM_ARCH_COMMAND_R,
LLM_ARCH_DBRX,
LLM_ARCH_OLMO,
LLM_ARCH_OLMOE,
LLM_ARCH_OPENELM,
LLM_ARCH_ARCTIC,
LLM_ARCH_DEEPSEEK2,
@@ -241,6 +243,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_ORION, "orion" },
{ LLM_ARCH_INTERNLM2, "internlm2" },
{ LLM_ARCH_MINICPM, "minicpm" },
{ LLM_ARCH_MINICPM3, "minicpm3" },
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_GEMMA2, "gemma2" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
@@ -249,6 +252,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_COMMAND_R, "command-r" },
{ LLM_ARCH_DBRX, "dbrx" },
{ LLM_ARCH_OLMO, "olmo" },
{ LLM_ARCH_OLMOE, "olmoe" },
{ LLM_ARCH_OPENELM, "openelm" },
{ LLM_ARCH_ARCTIC, "arctic" },
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
@@ -1034,6 +1038,29 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
},
},
{
LLM_ARCH_MINICPM3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" },
{ LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" },
{ LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" },
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
},
},
{
LLM_ARCH_GEMMA,
{
@@ -1168,6 +1195,26 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_OLMOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_OPENELM,
{
@@ -2252,6 +2299,7 @@ enum e_model {
MODEL_MEDIUM,
MODEL_LARGE,
MODEL_XL,
MODEL_A1_7B,
MODEL_A2_7B,
MODEL_8x7B,
MODEL_8x22B,
@@ -5216,6 +5264,7 @@ static const char * llama_model_type_name(e_model type) {
case MODEL_MEDIUM: return "0.4B";
case MODEL_LARGE: return "0.8B";
case MODEL_XL: return "1.5B";
case MODEL_A1_7B: return "A1.7B";
case MODEL_A2_7B: return "A2.7B";
case MODEL_8x7B: return "8x7B";
case MODEL_8x22B: return "8x22B";
@@ -5390,6 +5439,17 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_MINICPM3:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q);
ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
switch (hparams.n_layer) {
case 62: model.type = e_model::MODEL_4B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_GROK:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -5755,6 +5815,14 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_OLMOE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 16: model.type = e_model::MODEL_A1_7B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_OPENELM:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -6897,6 +6965,54 @@ static bool llm_load_tensors(
}
}
} break;
case LLM_ARCH_MINICPM3:
{
const int64_t n_embd_head_qk_rope = hparams.n_rot;
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
const int64_t q_lora_rank = hparams.n_lora_q;
const int64_t kv_lora_rank = hparams.n_lora_kv;
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
// output
{
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (model.output == NULL) {
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
}
}
for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
auto & layer = model.layers[i];
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.attn_q_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank});
layer.attn_kv_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank});
layer.wq_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank});
layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k});
layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)});
layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd});
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
layer.rope_long = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight"), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
layer.rope_short = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
}
} break;
case LLM_ARCH_GROK:
{
if (n_expert == 0) {
@@ -7934,6 +8050,44 @@ static bool llm_load_tensors(
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
}
} break;
case LLM_ARCH_OLMOE:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
// output
{
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
}
for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
auto & layer = model.layers[i];
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd});
layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd});
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
GGML_ASSERT(n_expert > 0);
GGML_ASSERT(n_expert_used > 0);
// MoE branch
layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert});
layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert});
layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert});
}
} break;
case LLM_ARCH_OPENELM:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -12843,6 +12997,215 @@ struct llm_build_context {
return gf;
}
struct ggml_cgraph * build_minicpm3() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
//TODO: if the model varies, these parameters need to be read from the model
const int64_t n_embd_base = 256;
const float scale_embd = 12.0f;
const float scale_depth = 1.4f;
const float kq_scale = 1.0f / sqrtf(float(hparams.n_embd_head_k));
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
const uint32_t kv_lora_rank = hparams.n_lora_kv;
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
// scale the input embeddings
inpL = ggml_scale(ctx0, inpL, scale_embd);
cb(inpL, "inp_scaled", -1);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = build_inp_pos();
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
struct ggml_tensor * rope_factors = build_rope_factors(il);
// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
// self_attention
{
struct ggml_tensor * q = NULL;
// {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens}
q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
cb(q, "q", il);
q = llm_build_norm(ctx0, q, hparams,
model.layers[il].attn_q_a_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(q, "q", il);
// {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
cb(q, "q", il);
// split into {n_head * n_embd_head_qk_nope, n_tokens}
struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
ggml_row_size(q->type, hparams.n_embd_head_k),
ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
0);
cb(q_nope, "q_nope", il);
// and {n_head * n_embd_head_qk_rope, n_tokens}
struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
ggml_row_size(q->type, hparams.n_embd_head_k),
ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
ggml_row_size(q->type, n_embd_head_qk_nope));
cb(q_pe, "q_pe", il);
// {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
struct ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
cb(kv_pe_compresseed, "kv_pe_compresseed", il);
// split into {kv_lora_rank, n_tokens}
struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
kv_pe_compresseed->nb[1],
0);
cb(kv_compressed, "kv_compressed", il);
// and {n_embd_head_qk_rope, n_tokens}
struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
kv_pe_compresseed->nb[1],
kv_pe_compresseed->nb[1],
ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
cb(k_pe, "k_pe", il);
kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
model.layers[il].attn_kv_a_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(kv_compressed, "kv_compressed", il);
// {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
cb(kv, "kv", il);
// split into {n_head * n_embd_head_qk_nope, n_tokens}
struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
0);
cb(k_nope, "k_nope", il);
// and {n_head * n_embd_head_v, n_tokens}
struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
ggml_row_size(kv->type, (n_embd_head_qk_nope)));
cb(v_states, "v_states", il);
v_states = ggml_cont(ctx0, v_states);
cb(v_states, "v_states", il);
v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
0);
cb(v_states, "v_states", il);
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
q_pe = ggml_rope_ext(
ctx0, q_pe, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(q_pe, "q_pe", il);
// shared RoPE key
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
k_pe = ggml_rope_ext(
ctx0, k_pe, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(k_pe, "k_pe", il);
struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
cb(q_states, "q_states", il);
struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
cb(k_states, "k_states", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, NULL,
k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
// scale_res - scale the hidden states for residual connection
const float scale_res = scale_depth/sqrtf(float(n_layer));
cur = ggml_scale(ctx0, cur, scale_res);
cb(cur, "hidden_scaled", il);
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network
{
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
cur = llm_build_ffn(ctx0, lctx, cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
}
// scale the hidden states for residual connection
cur = ggml_scale(ctx0, cur, scale_res);
cb(cur, "hidden_scaled_ffn", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = llm_build_norm(ctx0, cur, hparams,
model.output_norm, NULL,
LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
// lm_head scaling
const float scale_lmhead = float(n_embd_base)/float(n_embd);
cur = ggml_scale(ctx0, cur, scale_lmhead);
cb(cur, "lmhead_scaling", -1);
// lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
struct ggml_cgraph * build_gemma() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
@@ -13539,6 +13902,134 @@ struct llm_build_context {
return gf;
}
// based on the build_qwen2moe() function, changes:
// * removed shared experts
// * removed bias
// * added q, k norm
struct ggml_cgraph * build_olmoe() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
// mutable variable, needed during the last layer of the computation to skip unused tokens
int32_t n_tokens = this->n_tokens;
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = build_inp_pos();
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
// self_attention
{
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);
Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur_normed", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur_rope", il);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur_rope", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
n_tokens = n_outputs;
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// MoE branch
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
cur = llm_build_moe_ffn(ctx0, lctx, cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
n_expert, n_expert_used,
LLM_FFN_SILU, false,
false, 0.0,
cb, il);
cb(cur, "ffn_moe_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = llm_build_norm(ctx0, cur, hparams,
model.output_norm, NULL,
LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
// lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
struct ggml_cgraph * build_openelm() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
@@ -15383,6 +15874,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_minicpm();
} break;
case LLM_ARCH_MINICPM3:
{
result = llm.build_minicpm3();
} break;
case LLM_ARCH_GEMMA:
{
result = llm.build_gemma();
@@ -15415,6 +15910,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_olmo();
} break;
case LLM_ARCH_OLMOE:
{
result = llm.build_olmoe();
} break;
case LLM_ARCH_OPENELM:
{
result = llm.build_openelm();
@@ -18599,6 +19098,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_QWEN:
case LLM_ARCH_QWEN2:
case LLM_ARCH_QWEN2MOE:
case LLM_ARCH_OLMOE:
case LLM_ARCH_PHI2:
case LLM_ARCH_PHI3:
case LLM_ARCH_GEMMA:
@@ -18609,6 +19109,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_CODESHELL:
case LLM_ARCH_NEMOTRON:
case LLM_ARCH_EXAONE:
case LLM_ARCH_MINICPM3:
return LLAMA_ROPE_TYPE_NEOX;
// all model arches should be listed explicitly here