-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
schema.go
94 lines (80 loc) · 2.3 KB
/
schema.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
package bridge
import (
"errors"
"github.com/DataDog/go-python3"
"github.com/apache/arrow/go/arrow"
)
// PySchemaFromPyTable returns a pyarrow schema from a pyarrow Table.
// PySchemaFromPyTable returns the "schema" attribute of a pyarrow Table.
//
// An error is returned when the attribute lookup yields nil. On success the
// caller owns the returned object and should release it with DecRef
// (presumably a new reference, matching how the other helpers in this file
// DecRef objects obtained via GetAttrString — confirm against callers).
func PySchemaFromPyTable(pyTable *python3.PyObject) (*python3.PyObject, error) {
	if schemaObj := pyTable.GetAttrString("schema"); schemaObj != nil {
		return schemaObj, nil
	}
	return nil, errors.New("could not get pySchema")
}
// PySchemaToSchema given a Python schema gets the Go Arrow schema.
// PySchemaToSchema converts a pyarrow Schema object into a Go Arrow schema.
//
// It works in two steps:
//  1. Read the field names from the Python schema.
//  2. Look up and convert each field by name.
//
// The references held on the field names are released before returning,
// whether the conversion succeeds or fails. The resulting schema carries no
// metadata.
func PySchemaToSchema(pySchema *python3.PyObject) (*arrow.Schema, error) {
	names, nameErr := getPyFieldNames(pySchema)
	if nameErr != nil {
		return nil, nameErr
	}
	// Every name was handed to us with its own reference; drop them all once
	// the fields have been converted.
	defer func() {
		for _, name := range names {
			name.DecRef()
		}
	}()

	var fields []arrow.Field
	var fieldErr error
	if fields, fieldErr = getFields(pySchema, names); fieldErr != nil {
		return nil, fieldErr
	}
	return arrow.NewSchema(fields, nil), nil
}
// getPyFieldNames returns the field names of a pyarrow schema, read from its
// "names" attribute, as a slice of Python objects.
//
// Ownership: each returned element carries its own reference, taken via
// IncRef, and the caller must DecRef every element when done. On any error
// no references are leaked, because names collected so far are released
// before returning.
//
// Errors are returned when:
//   - the "names" attribute cannot be read;
//   - the attribute is not a Python list;
//   - a list element cannot be retrieved.
func getPyFieldNames(pySchema *python3.PyObject) ([]*python3.PyObject, error) {
	pyFieldNames := pySchema.GetAttrString("names")
	if pyFieldNames == nil {
		return nil, errors.New("could not get pyFieldNames")
	}
	defer pyFieldNames.DecRef()

	// verify the result is a list
	if !python3.PyList_Check(pyFieldNames) {
		return nil, errors.New("not a list of field names")
	}

	length := python3.PyList_Size(pyFieldNames)
	pyNames := make([]*python3.PyObject, 0, length)
	for i := 0; i < length; i++ {
		// PyList_GetItem yields a borrowed reference owned by the list, and
		// the list is released by the deferred DecRef above. We therefore
		// take our own reference so each name outlives the list.
		pyName := python3.PyList_GetItem(pyFieldNames, i)
		if pyName == nil {
			// The caller never receives this slice on error, so the
			// references already taken must be released here or they leak.
			for _, taken := range pyNames {
				taken.DecRef()
			}
			return nil, errors.New("could not get name")
		}
		pyName.IncRef()
		pyNames = append(pyNames, pyName)
	}
	return pyNames, nil
}
// getFields converts every named field of pySchema into an arrow.Field.
//
// The result has one entry per element of pyFieldNames, in the same order.
// The first lookup or conversion failure aborts the whole operation, and
// that error is returned with a nil slice.
func getFields(pySchema *python3.PyObject, pyFieldNames []*python3.PyObject) ([]arrow.Field, error) {
	out := make([]arrow.Field, len(pyFieldNames))
	for idx := range pyFieldNames {
		converted, convErr := getField(pySchema, pyFieldNames[idx])
		if convErr != nil {
			return nil, convErr
		}
		out[idx] = *converted
	}
	return out, nil
}
// getField looks up fieldName in a pyarrow schema by calling the schema's
// field_by_name method, then converts the result with PyFieldToField.
//
// pyarrow documents field_by_name as returning "Field or None", e.g. when
// the name matches more than one field. A None result is not a nil pointer,
// so it is rejected explicitly here rather than passed on to
// PyFieldToField.
//
// The object returned by the call is released (DecRef) before returning.
func getField(schema *python3.PyObject, fieldName *python3.PyObject) (*arrow.Field, error) {
	pyField := CallPyFunc(schema, "field_by_name", fieldName)
	if pyField == nil {
		return nil, errors.New("could not get pyField")
	}
	defer pyField.DecRef()

	if pyField == python3.Py_None {
		return nil, errors.New("could not get pyField: field_by_name returned None")
	}

	field, err := PyFieldToField(pyField)
	if err != nil {
		return nil, err
	}
	return field, nil
}