diff --git a/clip.hpp b/clip.hpp
index 2a6b08c0d..546704c8b 100644
--- a/clip.hpp
+++ b/clip.hpp
@@ -553,10 +553,9 @@ class CLIPEmbeddings : public GGMLBlock {
     void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
         enum ggml_type token_wtype = GGML_TYPE_F32;
         if (!force_clip_f32) {
-            auto tensor_type = tensor_types.find(prefix + "token_embedding.weight");
-            std::set allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0};
-            if (tensor_type != tensor_types.end() && allow_types.find(tensor_type->second) != allow_types.end()) {
-                token_wtype = tensor_type->second;
+            token_wtype = get_type(prefix + "token_embedding.weight", tensor_types, GGML_TYPE_F32);
+            if (!support_get_rows(token_wtype)) {
+                token_wtype = GGML_TYPE_F32;
             }
         }
         enum ggml_type position_wtype = GGML_TYPE_F32;
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index a5f61ea46..1736ed110 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -1967,13 +1967,24 @@ class Linear : public UnaryBlock {
     }
 };
 
+__STATIC_INLINE__ bool support_get_rows(ggml_type wtype) {
+    std::set allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0};
+    if (allow_types.find(wtype) != allow_types.end()) {
+        return true;
+    }
+    return false;
+}
+
 class Embedding : public UnaryBlock {
 protected:
     int64_t embedding_dim;
     int64_t num_embeddings;
     void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
         enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
-        params["weight"] = ggml_new_tensor_2d(ctx, wtype, embedding_dim, num_embeddings);
+        if (!support_get_rows(wtype)) {
+            wtype = GGML_TYPE_F32;
+        }
+        params["weight"] = ggml_new_tensor_2d(ctx, wtype, embedding_dim, num_embeddings);
     }
 
 public: