Skip to content

Commit

Permalink
test: preliminary tests and merge fix for authv2 (#3584)
Browse files Browse the repository at this point in the history
* add api key to existing app tests, add preliminary auth test

Signed-off-by: Dave Lee <[email protected]>

* small fix, run test

Signed-off-by: Dave Lee <[email protected]>

* status on non-opaque

Signed-off-by: Dave Lee <[email protected]>

* tweak auth error

Signed-off-by: Dave Lee <[email protected]>

* exp

Signed-off-by: Dave Lee <[email protected]>

* quick fix on real laptop

Signed-off-by: Dave Lee <[email protected]>

* add downloader version that allows providing an auth header

Signed-off-by: Dave Lee <[email protected]>

* stash some devcontainer fixes during testing

Signed-off-by: Dave Lee <[email protected]>

* s2

Signed-off-by: Dave Lee <[email protected]>

* s

Signed-off-by: Dave Lee <[email protected]>

* done with experiment

Signed-off-by: Dave Lee <[email protected]>

* done with experiment

Signed-off-by: Dave Lee <[email protected]>

* after merge fix

Signed-off-by: Dave Lee <[email protected]>

* rename and fix

Signed-off-by: Dave Lee <[email protected]>

---------

Signed-off-by: Dave Lee <[email protected]>
Co-authored-by: Ettore Di Giacinto <[email protected]>
  • Loading branch information
dave-gray101 and mudler authored Sep 24, 2024
1 parent 69d2902 commit 90cacb9
Show file tree
Hide file tree
Showing 12 changed files with 95 additions and 41 deletions.
2 changes: 2 additions & 0 deletions .devcontainer-scripts/utils.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# Param 2: email
#
config_user() {
echo "Configuring git for $1 <$2>"
local gcn=$(git config --global user.name)
if [ -z "${gcn}" ]; then
echo "Setting up git user / remote"
Expand All @@ -24,6 +25,7 @@ config_user() {
# Param 2: remote url
#
config_remote() {
echo "Adding git remote and fetching $2 as $1"
local gr=$(git remote -v | grep $1)
if [ -z "${gr}" ]; then
git remote add $1 $2
Expand Down
5 changes: 2 additions & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,8 @@ RUN if [ "${FFMPEG}" = "true" ]; then \

RUN apt-get update && \
apt-get install -y --no-install-recommends \
ssh less && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
ssh less wget
# For the devcontainer, leave apt functional in case additional devtools are needed at runtime.

RUN go install github.com/go-delve/delve/cmd/dlv@latest

Expand Down
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,9 @@ clean-tests:
rm -rf test-dir
rm -rf core/http/backend-assets

clean-dc: clean
cp -r /build/backend-assets /workspace/backend-assets

## Build:
build: prepare backend-assets grpcs ## Build the project
$(info ${GREEN}I local-ai build info:${RESET})
Expand Down
4 changes: 2 additions & 2 deletions core/gallery/gallery.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func AvailableGalleryModels(galleries []config.Gallery, basePath string) ([]*Gal
func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) {
var refFile string
uri := downloader.URI(url)
err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error {
err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error {
refFile = string(d)
if len(refFile) == 0 {
return fmt.Errorf("invalid reference file at url %s: %s", url, d)
Expand All @@ -156,7 +156,7 @@ func getGalleryModels(gallery config.Gallery, basePath string) ([]*GalleryModel,
}
uri := downloader.URI(gallery.URL)

err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error {
err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error {
return yaml.Unmarshal(d, &models)
})
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion core/gallery/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ type PromptTemplate struct {
func GetGalleryConfigFromURL(url string, basePath string) (Config, error) {
var config Config
uri := downloader.URI(url)
err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error {
err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error {
return yaml.Unmarshal(d, &config)
})
if err != nil {
Expand Down
18 changes: 0 additions & 18 deletions core/http/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,6 @@ import (
"github.com/rs/zerolog/log"
)

func readAuthHeader(c *fiber.Ctx) string {
authHeader := c.Get("Authorization")

// elevenlabs
xApiKey := c.Get("xi-api-key")
if xApiKey != "" {
authHeader = "Bearer " + xApiKey
}

// anthropic
xApiKey = c.Get("x-api-key")
if xApiKey != "" {
authHeader = "Bearer " + xApiKey
}

return authHeader
}

// Embed a directory
//
//go:embed static/*
Expand Down
69 changes: 62 additions & 7 deletions core/http/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ import (
"github.com/sashabaranov/go-openai/jsonschema"
)

const apiKey = "joshua"
const bearerKey = "Bearer " + apiKey

const testPrompt = `### System:
You are an AI assistant that follows instruction extremely well. Help as much as you can.
Expand All @@ -50,11 +53,19 @@ type modelApplyRequest struct {

func getModelStatus(url string) (response map[string]interface{}) {
// Create the HTTP request
resp, err := http.Get(url)
req, err := http.NewRequest("GET", url, nil)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)
if err != nil {
fmt.Println("Error creating request:", err)
return
}
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
fmt.Println("Error sending request:", err)
return
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
Expand All @@ -72,14 +83,15 @@ func getModelStatus(url string) (response map[string]interface{}) {
return
}

func getModels(url string) (response []gallery.GalleryModel) {
func getModels(url string) ([]gallery.GalleryModel, error) {
response := []gallery.GalleryModel{}
uri := downloader.URI(url)
// TODO: No tests currently seem to exercise file:// urls. Fix?
uri.DownloadAndUnmarshal("", func(url string, i []byte) error {
err := uri.DownloadWithAuthorizationAndCallback("", bearerKey, func(url string, i []byte) error {
// Unmarshal YAML data into a struct
return json.Unmarshal(i, &response)
})
return
return response, err
}

func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) {
Expand All @@ -101,6 +113,7 @@ func postModelApplyRequest(url string, request modelApplyRequest) (response map[
return
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)

// Make the request
client := &http.Client{}
Expand Down Expand Up @@ -140,6 +153,7 @@ func postRequestJSON[B any](url string, bodyJson *B) error {
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)

client := &http.Client{}
resp, err := client.Do(req)
Expand Down Expand Up @@ -175,6 +189,7 @@ func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson *
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", bearerKey)

client := &http.Client{}
resp, err := client.Do(req)
Expand All @@ -195,6 +210,35 @@ func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson *
return json.Unmarshal(body, respJson)
}

func postInvalidRequest(url string) (error, int) {

req, err := http.NewRequest("POST", url, bytes.NewBufferString("invalid request"))
if err != nil {
return err, -1
}

req.Header.Set("Content-Type", "application/json")

client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return err, -1
}

defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return err, -1
}

if resp.StatusCode < 200 || resp.StatusCode >= 400 {
return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)), resp.StatusCode
}

return nil, resp.StatusCode
}

//go:embed backend-assets/*
var backendAssets embed.FS

Expand Down Expand Up @@ -260,6 +304,7 @@ var _ = Describe("API test", func() {
config.WithContext(c),
config.WithGalleries(galleries),
config.WithModelPath(modelDir),
config.WithApiKeys([]string{apiKey}),
config.WithBackendAssets(backendAssets),
config.WithBackendAssetsOutput(backendAssetsDir))...)
Expect(err).ToNot(HaveOccurred())
Expand All @@ -269,7 +314,7 @@ var _ = Describe("API test", func() {

go app.Listen("127.0.0.1:9090")

defaultConfig := openai.DefaultConfig("")
defaultConfig := openai.DefaultConfig(apiKey)
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"

client2 = openaigo.NewClient("")
Expand All @@ -295,10 +340,19 @@ var _ = Describe("API test", func() {
Expect(err).To(HaveOccurred())
})

Context("Auth Tests", func() {
It("Should fail if the api key is missing", func() {
err, sc := postInvalidRequest("http://127.0.0.1:9090/models/available")
Expect(err).ToNot(BeNil())
Expect(sc).To(Equal(403))
})
})

Context("Applying models", func() {

It("applies models from a gallery", func() {
models := getModels("http://127.0.0.1:9090/models/available")
models, err := getModels("http://127.0.0.1:9090/models/available")
Expect(err).To(BeNil())
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models))
Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models))
Expand Down Expand Up @@ -331,7 +385,8 @@ var _ = Describe("API test", func() {
Expect(content["backend"]).To(Equal("bert-embeddings"))
Expect(content["foo"]).To(Equal("bar"))

models = getModels("http://127.0.0.1:9090/models/available")
models, err = getModels("http://127.0.0.1:9090/models/available")
Expect(err).To(BeNil())
Expect(len(models)).To(Equal(2), fmt.Sprint(models))
Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2")))
Expect(models[1].Name).To(Or(Equal("bert"), Equal("bert2")))
Expand Down
3 changes: 2 additions & 1 deletion core/http/middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.Er
if applicationConfig.OpaqueErrors {
return ctx.SendStatus(403)
}
return ctx.Status(403).SendString(err.Error())
}
if applicationConfig.OpaqueErrors {
return ctx.SendStatus(500)
Expand Down Expand Up @@ -90,4 +91,4 @@ func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig
}
}
return func(c *fiber.Ctx) bool { return false }
}
}
2 changes: 1 addition & 1 deletion embedded/embedded.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func init() {
func GetRemoteLibraryShorteners(url string, basePath string) (map[string]string, error) {
remoteLibrary := map[string]string{}
uri := downloader.URI(url)
err := uri.DownloadAndUnmarshal(basePath, func(_ string, i []byte) error {
err := uri.DownloadWithCallback(basePath, func(_ string, i []byte) error {
return yaml.Unmarshal(i, &remoteLibrary)
})
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
module github.com/mudler/LocalAI

go 1.22.0
go 1.23

toolchain go1.22.4
toolchain go1.23.1

require (
dario.cat/mergo v1.0.0
Expand Down
18 changes: 15 additions & 3 deletions pkg/downloader/uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ const (

type URI string

func (uri URI) DownloadAndUnmarshal(basePath string, f func(url string, i []byte) error) error {
func (uri URI) DownloadWithCallback(basePath string, f func(url string, i []byte) error) error {
return uri.DownloadWithAuthorizationAndCallback(basePath, "", f)
}

func (uri URI) DownloadWithAuthorizationAndCallback(basePath string, authorization string, f func(url string, i []byte) error) error {
url := uri.ResolveURL()

if strings.HasPrefix(url, LocalPrefix) {
Expand All @@ -41,7 +45,6 @@ func (uri URI) DownloadAndUnmarshal(basePath string, f func(url string, i []byte
if err != nil {
return err
}
// ???
resolvedBasePath, err := filepath.EvalSymlinks(basePath)
if err != nil {
return err
Expand All @@ -63,7 +66,16 @@ func (uri URI) DownloadAndUnmarshal(basePath string, f func(url string, i []byte
}

// Send a GET request to the URL
response, err := http.Get(url)

req, err := http.NewRequest("GET", url, nil)
if err != nil {
return err
}
if authorization != "" {
req.Header.Add("Authorization", authorization)
}

response, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/downloader/uri_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ var _ = Describe("Gallery API tests", func() {
It("parses github with a branch", func() {
uri := URI("github:go-skynet/model-gallery/gpt4all-j.yaml")
Expect(
uri.DownloadAndUnmarshal("", func(url string, i []byte) error {
uri.DownloadWithCallback("", func(url string, i []byte) error {
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
return nil
}),
Expand All @@ -21,7 +21,7 @@ var _ = Describe("Gallery API tests", func() {
uri := URI("github:go-skynet/model-gallery/gpt4all-j.yaml@main")

Expect(
uri.DownloadAndUnmarshal("", func(url string, i []byte) error {
uri.DownloadWithCallback("", func(url string, i []byte) error {
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
return nil
}),
Expand All @@ -30,7 +30,7 @@ var _ = Describe("Gallery API tests", func() {
It("parses github with urls", func() {
uri := URI("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")
Expect(
uri.DownloadAndUnmarshal("", func(url string, i []byte) error {
uri.DownloadWithCallback("", func(url string, i []byte) error {
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
return nil
}),
Expand Down

0 comments on commit 90cacb9

Please sign in to comment.