Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 109 additions & 49 deletions pkg/httpassert/extracter.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package httpassert
import (
"fmt"
"reflect"
"regexp"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -21,68 +22,127 @@ type Extractor func(t *testing.T, actual any) any
// request.Expect().JsonPath("$.data.id", httpassert.ExtractTo(&id))
func ExtractTo(ptr any) Extractor {
return func(t *testing.T, actual any) any {
targetVal := reflect.ValueOf(ptr)
if targetVal.Kind() != reflect.Ptr || targetVal.IsNil() {
assert.Fail(t, "ExtractTo requires a non-nil pointer")
return extractInto(t, ptr, actual, nil)
}
}

func ExtractRegexTo(value string, ptr any) Extractor {
return func(t *testing.T, actual any) any {
return extractInto(t, ptr, actual, func(t *testing.T, v any, dstType reflect.Type) (any, bool) {
s, ok := v.(string)
if !ok {
assert.Fail(t, fmt.Sprintf("ExtractRegexTo expects actual to be string, got %T", v))
return nil, false
}

re := regexp.MustCompile(value)

m := re.FindStringSubmatch(s)
if m == nil {
assert.Fail(t, fmt.Sprintf("ExtractRegexTo no match for %q in %q", re.String(), s))
return nil, false
}

// m[0] full match, m[1:] groups
groups := m[1:]
if len(groups) == 0 {
groups = []string{m[0]}
}

if dstType.Kind() == reflect.Slice {
return groups, true
}
return groups[0], true
})
}
}

func extractInto(
t *testing.T,
ptr any,
actual any,
preprocess func(t *testing.T, actual any, dstType reflect.Type) (any, bool),
) any {
target := reflect.ValueOf(ptr)
if target.Kind() != reflect.Ptr || target.IsNil() {
assert.Fail(t, "ExtractTo requires a non-nil pointer")
return nil
}
if actual == nil {
assert.Fail(t, "ExtractTo actual value is nil")
return nil
}

dst := target.Elem()
dstType := dst.Type()

if preprocess != nil {
var ok bool
actual, ok = preprocess(t, actual, dstType)
if !ok {
return nil
}

if actual == nil {
assert.Fail(t, "ExtractTo actual value is nil")
return nil
}
}

outVal := reflect.ValueOf(actual)
targetType := targetVal.Elem().Type()
src := reflect.ValueOf(actual)

// Direct assign / convert for simple values (string, int, etc.)
if outVal.Type().AssignableTo(targetType) {
targetVal.Elem().Set(outVal)
return ptr
}
// unwrap interface{}
if src.IsValid() && src.Kind() == reflect.Interface && !src.IsNil() {
src = reflect.ValueOf(src.Interface())
}

if outVal.Type().ConvertibleTo(targetType) {
targetVal.Elem().Set(outVal.Convert(targetType))
return ptr
}
// direct assign / convert
if src.IsValid() && src.Type().AssignableTo(dstType) {
dst.Set(src)
return ptr
}
if src.IsValid() && src.Type().ConvertibleTo(dstType) {
dst.Set(src.Convert(dstType))
return ptr
}

// Special handling for slices, e.g. []interface{} -> []string
if outVal.Kind() == reflect.Slice && targetType.Kind() == reflect.Slice {
elemType := targetType.Elem()
n := outVal.Len()
dst := reflect.MakeSlice(targetType, n, n)

for i := 0; i < n; i++ {
src := outVal.Index(i)

// If it's interface{}, unwrap to the underlying concrete value.
if src.Kind() == reflect.Interface && !src.IsNil() {
src = reflect.ValueOf(src.Interface())
}

if src.Type().AssignableTo(elemType) {
dst.Index(i).Set(src)
continue
}

if src.Type().ConvertibleTo(elemType) {
dst.Index(i).Set(src.Convert(elemType))
continue
}

assert.Fail(t,
fmt.Sprintf("ExtractTo slice element type mismatch at index %d: cannot assign %v to %v",
i, src.Type(), elemType))
return nil
// slice handling: e.g. []interface{} -> []string
if src.IsValid() && src.Kind() == reflect.Slice && dstType.Kind() == reflect.Slice {
elemType := dstType.Elem()
n := src.Len()
out := reflect.MakeSlice(dstType, n, n)

for i := 0; i < n; i++ {
s := src.Index(i)

// unwrap interface{} elements
if s.Kind() == reflect.Interface && !s.IsNil() {
s = reflect.ValueOf(s.Interface())
}

if s.Type().AssignableTo(elemType) {
out.Index(i).Set(s)
continue
}
if s.Type().ConvertibleTo(elemType) {
out.Index(i).Set(s.Convert(elemType))
continue
}

targetVal.Elem().Set(dst)
return ptr
assert.Fail(t, fmt.Sprintf(
"ExtractTo slice element type mismatch at index %d: cannot assign %v to %v",
i, s.Type(), elemType,
))
return nil
}

assert.Fail(t,
fmt.Sprintf("ExtractTo type mismatch: cannot assign %v to %v",
outVal.Type(), targetType))
return nil
dst.Set(out)
return ptr
}

srcType := any("<invalid>")
if src.IsValid() {
srcType = src.Type()
}
assert.Fail(t, fmt.Sprintf("ExtractTo type mismatch: cannot assign %v to %v", srcType, dstType))
return nil
}
11 changes: 11 additions & 0 deletions pkg/httpassert/matcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package httpassert

import (
"fmt"
"regexp"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -45,3 +46,13 @@ func Contains(v string) Matcher {
return valid
}
}

// Regex checks if a string matches the given regular expression
// Example: ExpectJsonPath("$.data.name", httpassert.Regex("^foo.*bar$"))
func Regex(expr string) Matcher {
re := regexp.MustCompile(expr)

return func(t *testing.T, value any) bool {
return assert.Regexp(t, re, value)
}
}
19 changes: 19 additions & 0 deletions pkg/httpassert/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,32 @@ type Response interface {
JsonTemplateFile(path string, values map[string]any) Response
JsonFile(path string) Response

Header(name string, value any) Response

Body(body string) Response
GetJsonBodyObject(target any) Response
GetBody() string

Log() Response
}

func (r *responseImpl) Header(name string, value any) Response {
out := r.response.Header().Get(name)

switch v := value.(type) {
case Extractor:
v(r.t, out)
return r
case Matcher:
v(r.t, out)
return r

default:
assert.Equal(r.t, value, out)
return r
}
}

func (r *responseImpl) StatusCode(expected int) Response {
require.Equal(r.t, expected, r.response.Code)
return r
Expand Down
34 changes: 34 additions & 0 deletions pkg/httpassert/response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,40 @@ func TestResponse(t *testing.T) {
})
})

t.Run("Header", func(t *testing.T) {
router := http.NewServeMux()
router.HandleFunc("/api", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
})

t.Run("compare string", func(t *testing.T) {
m := New(t, router)

m.Get("/api").
Expect().
Header("Content-Type", "application/json")
})

t.Run("extract value to variable", func(t *testing.T) {
m := New(t, router)

var value string
m.Get("/api").
Expect().
Header("Content-Type", ExtractRegexTo(`application/(.*)`, &value))
assert.Equal(t, "json", value)
})

t.Run("use matcher", func(t *testing.T) {
m := New(t, router)

m.Get("/api").
Expect().
Header("Content-Type", Regex("application/[json]"))
})
})

t.Run("POST basic JSON", func(t *testing.T) {
request.Post("/json").
ContentType("application/json").
Expand Down
Loading