diff --git a/bindings/go/whisper.go b/bindings/go/whisper.go index 3ef73414d90..78a92eba8f7 100644 --- a/bindings/go/whisper.go +++ b/bindings/go/whisper.go @@ -16,6 +16,13 @@ import ( #include #include +// Grammar element type +typedef struct whisper_grammar_element whisper_grammar_element_t; + +// VAD context +typedef struct whisper_vad_context whisper_vad_context_t; +typedef struct whisper_vad_segments whisper_vad_segments_t; + extern void callNewSegment(void* user_data, int new); extern void callProgress(void* user_data, int progress); extern bool callEncoderBegin(void* user_data); @@ -71,6 +78,12 @@ type ( TokenData C.struct_whisper_token_data SamplingStrategy C.enum_whisper_sampling_strategy Params C.struct_whisper_full_params + ContextParams C.struct_whisper_context_params + VADContext C.struct_whisper_vad_context + VADSegments C.struct_whisper_vad_segments + VADParams C.struct_whisper_vad_params + GrammarElement C.struct_whisper_grammar_element + GrammarType C.enum_whisper_gretype ) /////////////////////////////////////////////////////////////////////////////// @@ -81,6 +94,16 @@ const ( SAMPLING_BEAM_SEARCH SamplingStrategy = C.WHISPER_SAMPLING_BEAM_SEARCH ) +const ( + GRAMMAR_END GrammarType = C.WHISPER_GRETYPE_END + GRAMMAR_ALT GrammarType = C.WHISPER_GRETYPE_ALT + GRAMMAR_RULE_REF GrammarType = C.WHISPER_GRETYPE_RULE_REF + GRAMMAR_CHAR GrammarType = C.WHISPER_GRETYPE_CHAR + GRAMMAR_CHAR_NOT GrammarType = C.WHISPER_GRETYPE_CHAR_NOT + GRAMMAR_CHAR_RNG_UPPER GrammarType = C.WHISPER_GRETYPE_CHAR_RNG_UPPER + GRAMMAR_CHAR_ALT GrammarType = C.WHISPER_GRETYPE_CHAR_ALT +) + const ( SampleRate = C.WHISPER_SAMPLE_RATE // Expected sample rate, samples per second SampleBits = uint16(unsafe.Sizeof(C.float(0))) * 8 // Sample size in bits @@ -398,6 +421,378 @@ func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 { return float32(C.whisper_full_get_token_p((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token))) } +/////////////////////////////////////////////////////////////////////////////// +// CONTEXT PARAMS + +// Get default context parameters +func Whisper_context_default_params() ContextParams { + return ContextParams(C.whisper_context_default_params()) +} + +// Initialize context from file with parameters +func Whisper_init_from_file_with_params(path string, params ContextParams) *Context { + cPath := C.CString(path) + defer C.free(unsafe.Pointer(cPath)) + + if ctx := C.whisper_init_from_file_with_params(cPath, (C.struct_whisper_context_params)(params)); ctx != nil { + return (*Context)(ctx) + } + return nil +} + +// Initialize context from buffer with parameters +func Whisper_init_from_buffer_with_params(buffer []byte, params ContextParams) *Context { + if ctx := C.whisper_init_from_buffer_with_params(unsafe.Pointer(&buffer[0]), C.size_t(len(buffer)), (C.struct_whisper_context_params)(params)); ctx != nil { + return (*Context)(ctx) + } + return nil +} + +/////////////////////////////////////////////////////////////////////////////// +// VAD (VOICE ACTIVITY DETECTION) + +// Get default VAD parameters +func Whisper_vad_default_params() VADParams { + return VADParams(C.whisper_vad_default_params()) +} + +// Get default VAD context parameters +func Whisper_vad_default_context_params() C.struct_whisper_vad_context_params { + return C.whisper_vad_default_context_params() +} + +// Initialize VAD context from file with default parameters +func Whisper_vad_init_from_file(path string) *VADContext { + params := Whisper_vad_default_context_params() + return Whisper_vad_init_from_file_with_params_struct(path, params) +} + +// Initialize VAD context from file with struct parameters +func Whisper_vad_init_from_file_with_params_struct(path string, params C.struct_whisper_vad_context_params) *VADContext { + cPath := C.CString(path) + defer C.free(unsafe.Pointer(cPath)) + + if vctx := C.whisper_vad_init_from_file_with_params(cPath, params); vctx != nil { + return (*VADContext)(vctx) + } + return nil +} + +// Initialize VAD context from file with parameters +func Whisper_vad_init_from_file_with_params(path string, nThreads int, useGPU bool, gpuDevice int) *VADContext { + var cparams C.struct_whisper_vad_context_params + cparams.n_threads = C.int(nThreads) + cparams.use_gpu = C.bool(useGPU) + cparams.gpu_device = C.int(gpuDevice) + + return Whisper_vad_init_from_file_with_params_struct(path, cparams) +} + +// Free VAD context +func (vctx *VADContext) Whisper_vad_free() { + C.whisper_vad_free((*C.struct_whisper_vad_context)(vctx)) +} + +// Detect speech in audio samples +func (vctx *VADContext) Whisper_vad_detect_speech(samples []float32) bool { + return bool(C.whisper_vad_detect_speech( + (*C.struct_whisper_vad_context)(vctx), + (*C.float)(&samples[0]), + C.int(len(samples)), + )) +} + +// Get number of probabilities +func (vctx *VADContext) Whisper_vad_n_probs() int { + return int(C.whisper_vad_n_probs((*C.struct_whisper_vad_context)(vctx))) +} + +// Get probabilities array +func (vctx *VADContext) Whisper_vad_probs() []float32 { + n := vctx.Whisper_vad_n_probs() + probs := C.whisper_vad_probs((*C.struct_whisper_vad_context)(vctx)) + return (*[1 << 30]float32)(unsafe.Pointer(probs))[:n:n] +} + +// Get VAD segments from samples +func (vctx *VADContext) Whisper_vad_segments_from_samples(params VADParams, samples []float32) *VADSegments { + segments := C.whisper_vad_segments_from_samples( + (*C.struct_whisper_vad_context)(vctx), + (C.struct_whisper_vad_params)(params), + (*C.float)(&samples[0]), + C.int(len(samples)), + ) + return (*VADSegments)(segments) +} + +// Get VAD segments from pre-computed probabilities +func (vctx *VADContext) Whisper_vad_segments_from_probs(params VADParams) *VADSegments { + segments := C.whisper_vad_segments_from_probs( + (*C.struct_whisper_vad_context)(vctx), + (C.struct_whisper_vad_params)(params), + ) + return (*VADSegments)(segments) +} + +// Get number of segments +func (segments *VADSegments) Whisper_vad_segments_n_segments() int { + return int(C.whisper_vad_segments_n_segments((*C.struct_whisper_vad_segments)(segments))) +} + +// Get segment start time +func (segments *VADSegments) Whisper_vad_segments_get_segment_t0(i int) float32 { + return float32(C.whisper_vad_segments_get_segment_t0((*C.struct_whisper_vad_segments)(segments), C.int(i))) +} + +// Get segment end time +func (segments *VADSegments) Whisper_vad_segments_get_segment_t1(i int) float32 { + return float32(C.whisper_vad_segments_get_segment_t1((*C.struct_whisper_vad_segments)(segments), C.int(i))) +} + +// Free VAD segments +func (segments *VADSegments) Whisper_vad_free_segments() { + C.whisper_vad_free_segments((*C.struct_whisper_vad_segments)(segments)) +} + +/////////////////////////////////////////////////////////////////////////////// +// PARAMS BUILDER + +// ParamsBuilder provides a fluent API for building whisper_full_params +type ParamsBuilder struct { + params Params +} + +// NewParamsBuilder creates a new params builder with defaults +func NewParamsBuilder(ctx *Context, strategy SamplingStrategy) *ParamsBuilder { + return &ParamsBuilder{ + params: ctx.Whisper_full_default_params(strategy), + } +} + +// Build returns the final params +func (pb *ParamsBuilder) Build() Params { + return pb.params +} + +// SetThreads sets n_threads +func (pb *ParamsBuilder) SetThreads(n int) *ParamsBuilder { + pb.params.n_threads = C.int(n) + return pb +} + +// SetLanguage sets the language +func (pb *ParamsBuilder) SetLanguage(lang string) *ParamsBuilder { + pb.params.language = C.CString(lang) + return pb +} + +// SetTranslate sets translation mode +func (pb *ParamsBuilder) SetTranslate(translate bool) *ParamsBuilder { + pb.params.translate = C.bool(translate) + return pb +} + +// SetNoTimestamps disables timestamps +func (pb *ParamsBuilder) SetNoTimestamps(noTimestamps bool) *ParamsBuilder { + pb.params.no_timestamps = C.bool(noTimestamps) + return pb +} + +// SetTokenTimestamps enables token-level timestamps +func (pb *ParamsBuilder) SetTokenTimestamps(enabled bool) *ParamsBuilder { + pb.params.token_timestamps = C.bool(enabled) + return pb +} + +// SetSplitOnWord enables word-level splitting +func (pb *ParamsBuilder) SetSplitOnWord(enabled bool) *ParamsBuilder { + pb.params.split_on_word = C.bool(enabled) + return pb +} + +// SetMaxLen sets maximum segment length +func (pb *ParamsBuilder) SetMaxLen(maxLen int) *ParamsBuilder { + pb.params.max_len = C.int(maxLen) + return pb +} + +// SetAudioCtx sets audio context size +func (pb *ParamsBuilder) SetAudioCtx(audioCtx int) *ParamsBuilder { + pb.params.audio_ctx = C.int(audioCtx) + return pb +} + +// SetInitialPrompt sets the initial prompt +func (pb *ParamsBuilder) SetInitialPrompt(prompt string) *ParamsBuilder { + pb.params.initial_prompt = C.CString(prompt) + return pb +} + +// SetCarryInitialPrompt sets whether to carry initial prompt +func (pb *ParamsBuilder) SetCarryInitialPrompt(carry bool) *ParamsBuilder { + pb.params.carry_initial_prompt = C.bool(carry) + return pb +} + +// SetSuppressBlank sets suppress_blank +func (pb *ParamsBuilder) SetSuppressBlank(suppress bool) *ParamsBuilder { + pb.params.suppress_blank = C.bool(suppress) + return pb +} + +// SetSuppressNST sets suppress_nst (non-speech tokens) +func (pb *ParamsBuilder) SetSuppressNST(suppress bool) *ParamsBuilder { + pb.params.suppress_nst = C.bool(suppress) + return pb +} + +// SetSuppressRegex sets the regex pattern for token suppression +func (pb *ParamsBuilder) SetSuppressRegex(regex string) *ParamsBuilder { + pb.params.suppress_regex = C.CString(regex) + return pb +} + +// SetTemperature sets the sampling temperature +func (pb *ParamsBuilder) SetTemperature(temp float32) *ParamsBuilder { + pb.params.temperature = C.float(temp) + return pb +} + +// SetTemperatureInc sets the temperature increment +func (pb *ParamsBuilder) SetTemperatureInc(inc float32) *ParamsBuilder { + pb.params.temperature_inc = C.float(inc) + return pb +} + +// SetEntropyThold sets the entropy threshold +func (pb *ParamsBuilder) SetEntropyThold(thold float32) *ParamsBuilder { + pb.params.entropy_thold = C.float(thold) + return pb +} + +// SetLogprobThold sets the log probability threshold +func (pb *ParamsBuilder) SetLogprobThold(thold float32) *ParamsBuilder { + pb.params.logprob_thold = C.float(thold) + return pb +} + +// SetNoSpeechThold sets the no-speech threshold +func (pb *ParamsBuilder) SetNoSpeechThold(thold float32) *ParamsBuilder { + pb.params.no_speech_thold = C.float(thold) + return pb +} + +// SetGreedyBestOf sets best_of for greedy sampling +func (pb *ParamsBuilder) SetGreedyBestOf(bestOf int) *ParamsBuilder { + pb.params.greedy.best_of = C.int(bestOf) + return pb +} + +// SetBeamSize sets beam_size for beam search +func (pb *ParamsBuilder) SetBeamSize(beamSize int) *ParamsBuilder { + pb.params.beam_search.beam_size = C.int(beamSize) + return pb +} + +// SetTDRZ enables tinydiarize +func (pb *ParamsBuilder) SetTDRZ(enabled bool) *ParamsBuilder { + pb.params.tdrz_enable = C.bool(enabled) + return pb +} + +// SetDebugMode enables debug mode +func (pb *ParamsBuilder) SetDebugMode(enabled bool) *ParamsBuilder { + pb.params.debug_mode = C.bool(enabled) + return pb +} + +// SetGrammarPenalty sets the grammar penalty +func (pb *ParamsBuilder) SetGrammarPenalty(penalty float32) *ParamsBuilder { + pb.params.grammar_penalty = C.float(penalty) + return pb +} + +// SetVAD enables VAD with model path and parameters +func (pb *ParamsBuilder) SetVAD(enabled bool, modelPath string, vadParams VADParams) *ParamsBuilder { + pb.params.vad = C.bool(enabled) + if modelPath != "" { + pb.params.vad_model_path = C.CString(modelPath) + } + pb.params.vad_params = (C.struct_whisper_vad_params)(vadParams) + return pb +} + +/////////////////////////////////////////////////////////////////////////////// +// UTILITY FUNCTIONS + +// Get no_speech probability for a segment +func (ctx *Context) Whisper_full_get_segment_no_speech_prob(segment int) float32 { + return float32(C.whisper_full_get_segment_no_speech_prob((*C.struct_whisper_context)(ctx), C.int(segment))) +} + +// Get speaker turn information +func (ctx *Context) Whisper_full_get_segment_speaker_turn_next(segment int) bool { + return bool(C.whisper_full_get_segment_speaker_turn_next((*C.struct_whisper_context)(ctx), C.int(segment))) +} + +// Model information functions +func (ctx *Context) Whisper_model_n_vocab() int { + return int(C.whisper_model_n_vocab((*C.struct_whisper_context)(ctx))) +} + +func (ctx *Context) Whisper_model_n_audio_ctx() int { + return int(C.whisper_model_n_audio_ctx((*C.struct_whisper_context)(ctx))) +} + +func (ctx *Context) Whisper_model_n_audio_state() int { + return int(C.whisper_model_n_audio_state((*C.struct_whisper_context)(ctx))) +} + +func (ctx *Context) Whisper_model_n_audio_head() int { + return int(C.whisper_model_n_audio_head((*C.struct_whisper_context)(ctx))) +} + +func (ctx *Context) Whisper_model_n_audio_layer() int { + return int(C.whisper_model_n_audio_layer((*C.struct_whisper_context)(ctx))) +} + +func (ctx *Context) Whisper_model_n_text_ctx() int { + return int(C.whisper_model_n_text_ctx((*C.struct_whisper_context)(ctx))) +} + +func (ctx *Context) Whisper_model_n_text_state() int { + return int(C.whisper_model_n_text_state((*C.struct_whisper_context)(ctx))) +} + +func (ctx *Context) Whisper_model_n_text_head() int { + return int(C.whisper_model_n_text_head((*C.struct_whisper_context)(ctx))) +} + +func (ctx *Context) Whisper_model_n_text_layer() int { + return int(C.whisper_model_n_text_layer((*C.struct_whisper_context)(ctx))) +} + +func (ctx *Context) Whisper_model_n_mels() int { + return int(C.whisper_model_n_mels((*C.struct_whisper_context)(ctx))) +} + +func (ctx *Context) Whisper_model_ftype() int { + return int(C.whisper_model_ftype((*C.struct_whisper_context)(ctx))) +} + +func (ctx *Context) Whisper_model_type() int { + return int(C.whisper_model_type((*C.struct_whisper_context)(ctx))) +} + +func (ctx *Context) Whisper_model_type_readable() string { + return C.GoString(C.whisper_model_type_readable((*C.struct_whisper_context)(ctx))) +} + +// Get whisper version +func Whisper_version() string { + return C.GoString(C.whisper_version()) +} + /////////////////////////////////////////////////////////////////////////////// // CALLBACKS