Skip to content

Commit

Permalink
fix: port over more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
AidanAbd committed Jun 7, 2024
1 parent c82f55b commit 7a3e68e
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 6 deletions.
73 changes: 68 additions & 5 deletions latch_data_validation/data_validation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import collections.abc
import dataclasses
from enum import Enum
from itertools import chain
from types import NoneType, UnionType
import sys
from types import FrameType, NoneType, UnionType
from typing import (
Any,
ForwardRef,
Iterable,
Literal,
Mapping,
NewType,
Sequence,
TypeAlias,
TypeVar,
Expand All @@ -15,11 +19,31 @@
get_origin,
get_type_hints,
)
import typing

from opentelemetry.trace import get_tracer

tracer = get_tracer(__name__)

# Maps id(ForwardRef instance) -> the stack frame of the code that caused the
# reference to be created. The validator looks a ForwardRef up here and
# resolves its name against that frame's globals/locals.
# NOTE: holding a frame keeps its locals alive for as long as the entry exists.
forward_frames: dict[int, FrameType] = {}
real_init = ForwardRef.__init__

# Modules whose frames are typing plumbing rather than the user code that
# wrote the annotation; they are skipped when looking for the defining frame.
_TYPING_INTERNAL_MODULES = frozenset({"typing", "typing_extensions", "annotationlib"})


def init(self, *args, **kwargs):
    """Replacement for ``ForwardRef.__init__`` that records the defining frame.

    Walks up the call stack from the caller, skipping frames that belong to
    the typing machinery, and stores the first remaining frame in
    ``forward_frames`` under ``id(self)``. The original initializer is then
    invoked with the unchanged arguments.

    Frames are classified by their module ``__name__`` instead of by "same
    file as the immediate caller": with the latter, a ``ForwardRef`` built
    directly in user code would skip every frame of the user's own file and
    record an unrelated outer frame.
    """
    cur = sys._getframe().f_back
    while (
        cur is not None
        and cur.f_globals.get("__name__") in _TYPING_INTERNAL_MODULES
    ):
        cur = cur.f_back

    if cur is not None:
        forward_frames[id(self)] = cur
    else:
        # No usable frame: make sure a stale entry left behind by a dead
        # object whose id() was recycled is not attributed to this instance.
        forward_frames.pop(id(self), None)
    real_init(self, *args, **kwargs)


ForwardRef.__init__ = init

T = TypeVar("T")

JsonArray: TypeAlias = Sequence["JsonValue"]
Expand Down Expand Up @@ -127,12 +151,39 @@ def __str__(self):
return f"\n{self.explain()}"


# todo(maximsmol): generics
# todo(maximsmol): typing
def untraced_validate(x: JsonValue, cls: type[T]) -> T:
if dataclasses.is_dataclass(cls):
if isinstance(x, cls):
return x
if isinstance(cls, ForwardRef):
fr = typing.cast(ForwardRef, cls)

frame = forward_frames.get(id(cls))
if frame is None:
raise DataValidationError("untraced ForwardRef", x, cls)

f_globals = frame.f_globals
f_locals = frame.f_locals

next = f_globals.get(fr.__forward_arg__)
if next is None:
next = f_locals.get(fr.__forward_arg__)

if next is None:
raise DataValidationError("unresolvable ForwardRef", x, cls)

return untraced_validate(x, next)

if cls is Any:
return x

if isinstance(cls, NewType):
# todo(maximsmol): this probably needs to be typed properly on the gql client layer like enums
return untraced_validate(x, cls.__supertype__)

if dataclasses.is_dataclass(cls):
if not isinstance(x, dict):
raise DataValidationError("expected an object", x, cls)

Expand Down Expand Up @@ -222,8 +273,8 @@ def untraced_validate(x: JsonValue, cls: type[T]) -> T:
)

if issubclass(origin, collections.abc.Mapping):
if not isinstance(x, collections.abc.Mapping):
raise DataValidationError("expected a dict", x, cls)
if not isinstance(x, origin):
raise DataValidationError("mapping type does not match", x, cls)

key_type, value_type = get_args(cls)

Expand Down Expand Up @@ -255,7 +306,7 @@ def untraced_validate(x: JsonValue, cls: type[T]) -> T:

if len(errors) > 0:
raise DataValidationError(
"list items did not match schema", x, cls, children=errors
"mapping items did not match schema", x, cls, children=errors
)

if origin is collections.abc.Mapping:
Expand Down Expand Up @@ -291,6 +342,10 @@ def untraced_validate(x: JsonValue, cls: type[T]) -> T:
if issubclass(origin, Iterable):
if not isinstance(x, collections.abc.Iterable):
raise DataValidationError("expected an iterable", x, cls)
if isinstance(x, str):
raise DataValidationError("iterable type does not match", x, cls)
if not isinstance(x, origin):
raise DataValidationError("iterable type does not match", x, cls)

item_type = get_args(cls)[0]

Expand All @@ -307,7 +362,10 @@ def untraced_validate(x: JsonValue, cls: type[T]) -> T:
"list items did not match schema", x, cls, children=errors
)

if origin is collections.abc.Iterable:
if any(
origin is x
for x in [collections.abc.Iterable, collections.abc.Sequence]
):
return list(res)

return origin(res)
Expand Down Expand Up @@ -365,6 +423,10 @@ def untraced_validate(x: JsonValue, cls: type[T]) -> T:

return cls(**fields)

# todo(maximsmol): make conversions to enums and dataclasses optional
if issubclass(cls, Enum):
return cls(x)

if issubclass(cls, bool):
if not isinstance(x, bool):
raise DataValidationError("expected a boolean", x, cls)
Expand Down Expand Up @@ -397,6 +459,7 @@ def untraced_validate(x: JsonValue, cls: type[T]) -> T:
raise DataValidationError("[!Internal Error!] unknown type", x, cls)



def validate(x: JsonValue, cls: type[T]) -> T:
with tracer.start_as_current_span(
validate.__qualname__,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "latch-data-validation"
version = "0.1.8"
version = "0.1.9"
description = "Data validation for latch python backend services"
authors = ["Max Smolin <[email protected]>"]
license = "CC0 1.0"
Expand Down

0 comments on commit 7a3e68e

Please sign in to comment.