diff --git a/ai/ai.go b/ai/ai.go index 045df5b..f087ed2 100644 --- a/ai/ai.go +++ b/ai/ai.go @@ -11,6 +11,7 @@ import ( type API interface { Type() string Model() string + WithModel(name string) API Generate(system, prompt string, ctx json.RawMessage) (string, error) } @@ -43,6 +44,12 @@ func (a *OllamaAPI) Model() string { return a.ModelName } +func (a *OllamaAPI) WithModel(model string) API { + a2 := *a + a2.ModelName = model + return &a2 +} + func (a *OllamaAPI) Generate(system, prompt string, ctxU32s json.RawMessage) (string, error) { var context []uint32 // for type safety while maintaining interface @@ -55,10 +62,10 @@ func (a *OllamaAPI) Generate(system, prompt string, ctxU32s json.RawMessage) (st Context: context, Prompt: prompt, Stream: false, - Options: &Options{ - Temperature: 0.7, // Controls randomness (0.0 to 1.0) - TopP: 0.8, // Controls diversity (0.0 to 1.0) - }, + // Options: &Options{ + // Temperature: 0.7, // Controls randomness (0.0 to 1.0) + // TopP: 0.8, // Controls diversity (0.0 to 1.0) + // }, } jsonData, _ := json.Marshal(reqBody) @@ -162,6 +169,12 @@ type OpenAIResponse struct { } `json:"choices"` } +func (a *OpenAiAPI) WithModel(model string) API { + a2 := *a + a2.ModelName = model + return &a2 +} + func (a *OpenAiAPI) Generate(system, prompt string, ctxMessages json.RawMessage) (string, error) { reqBody := OpenAIRequest{ Model: a.ModelName, // Default OpenAI model, adjust as needed @@ -169,9 +182,9 @@ func (a *OpenAiAPI) Generate(system, prompt string, ctxMessages json.RawMessage) {Role: "system", Content: system}, {Role: "user", Content: prompt}, }, - Stream: false, - Temperature: 0.7, - TopP: 0.9, + Stream: false, + // Temperature: 0.7, + // TopP: 0.9, } jsonData, _ := json.Marshal(reqBody)