When doing any sort of tensor/array computation in `python` (via `numpy`, `pytorch`, `jax`, or others), it's more common than not to run into shape errors like the one below:

```python
import numpy as np

size1 = (2,3)
size2 = (4,3)

M1 = np.random.random(size=size1)
M2 = np.random.random(size=size2)

try:
    print(np.dot(M1,M2))
except Exception as e:
    print(e)
```

```
shapes (2,3) and (4,3) not aligned: 3 (dim 1) != 4 (dim 0)
```

And most of the time, these kinds of errors boil down to accidentally forgetting a reshape or transpose, like so:

```python
import numpy as np

size1 = (2,3)
size2 = (4,3)

M1 = np.random.random(size=size1)
M2 = np.random.random(size=size2).T

try:
    print(np.dot(M1,M2))
except Exception as e:
    print(e)
```

```
[[0.68812413 0.63491692 0.375332   1.22395427]
 [0.57381506 0.42578404 0.19132443 0.8889217 ]]
```

And while this is a mild case, shape bugs like these become more frequent as operations grow more complex and as more dimensions are involved.

Here's a slightly more complex example of a `Linear` implementation in `numpy` with a subtle shape bug.

```python
import numpy as np

def Linear(A, x, b):
    """
    Takes matrix A (m x n) times a vector x (n x 1) and
    adds a bias. The resulting ndarray is then ravelled
    into a vector of size (m).
    """
    Ax = np.dot(A, x)
    Axb = np.add(Ax, b)
    return np.ravel(Axb)

A = np.random.random(size=(4,4))
x = np.random.random(size=(4,1))
b = np.random.random(size=(4))

result = Linear(A, x, b)
print(result)
print(result.shape)
```

```
[1.18041914 1.87580329 0.93373901 1.48799234 1.4920404  2.18742455
 1.24536027 1.79961361 2.29649806 2.99188221 2.04981793 2.60407127
 1.31159899 2.00698314 1.06491886 1.6191722 ]
(16,)
```

The docstring of `Linear` clearly says the result should be size `m` (or `4`). So why did we end up with a vector of size `16`? If we dig into each function, we eventually find that the problem is in how `numpy` handles arithmetic between `ndarray`s of different shapes.

If we break down `Linear`, after `np.dot` we have an `ndarray` of shape `(4,1)`, which we then pass to `np.add` together with a vector of shape `(4,)`. And here lies our bug. We might naturally expect `np.add` to do this addition element-wise, but instead we've fallen into an array broadcasting trap. Array broadcasting is the set of rules `numpy` uses to do arithmetic on `ndarrays` of different shapes: shapes are compared starting from the trailing dimension, a missing leading dimension counts as size `1`, and any size-`1` dimension is stretched to match the other array. So our `(4,)` bias is treated as `(1,4)`, and instead of adding element-wise, `numpy` broadcasts `(4,1) + (1,4)` into a `(4,4)` matrix, which subsequently gets "raveled" into a size `16` vector.
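To see the trap in isolation, here is a small sketch that reproduces the shapes from `Linear` with placeholder arrays (`col` and `vec` are made-up names standing in for `np.dot(A, x)` and `b`). It also uses `np.broadcast_shapes` (available in `numpy>=1.20`), which reports the shape a broadcast would produce without doing any arithmetic.

```python
import numpy as np

col = np.zeros((4, 1))  # stands in for np.dot(A, x): shape (4, 1)
vec = np.arange(4.0)    # stands in for the bias b: shape (4,)

# Broadcasting right-aligns the shapes: (4, 1) vs (4,) is treated as
# (4, 1) vs (1, 4), and every size-1 axis is stretched to match.
out = np.add(col, vec)
print(out.shape)            # (4, 4)
print(np.ravel(out).shape)  # (16,)

# The same result shape, computed without doing any arithmetic.
print(np.broadcast_shapes((4, 1), (4,)))  # (4, 4)

# With a (4, 1) bias the shapes already match, so the add is element-wise.
print(np.add(col, vec.reshape(4, 1)).shape)  # (4, 1)
```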

The fix is easy: we just need to initialize `b` with shape `(4,1)` so that `numpy` treats the `np.add` as an element-wise addition.

```python
import numpy as np

def Linear(A, x, b):
    """
    Takes matrix A (m x n) times a vector x (n x 1) and
    adds a bias. The resulting ndarray is then ravelled
    into a vector of size (m).
    """
    Ax = np.dot(A, x)
    Axb = np.add(Ax, b)
    return np.ravel(Axb)

A = np.random.random(size=(4,4))
x = np.random.random(size=(4,1))
b = np.random.random(size=(4,1))

result = Linear(A, x, b)
print(result)
print(result.shape)
```

```
[1.15227694 1.24640271 0.63951685 1.13304944]
(4,)
```

We've solved the problem, but how can we be smarter to prevent this error from happening again?

## Existing ways to stop shape bugs

The simplest way to try to stop this shape bug is with good docs. Ideally we should always have good docs, but we can make a point of spelling out the shape expectations, like so:

```python
def Linear(A, x, b):
    """
    Args:
        A: ndarray of shape (M x N)
        x: ndarray of shape (N x 1)
        b: ndarray of shape (M x 1)

    Returns:
        Linear output ndarray of shape (M)
    """
    Ax = np.dot(A, x)     # Shape (M x 1)
    Axb = np.add(Ax, b)   # (M x 1) + (M x 1)
    return np.ravel(Axb)  # Shape (M)
```

Now while informative, this does nothing to stop us from hitting the same bug again. The only benefit it gives us is a somewhat easier debugging process.

We can do better.

Another, more preventative approach that complements good docs is to use assertions. By sprinkling `assert` statements with informative error messages throughout `Linear`, we can "fail early" and start debugging sooner:

```python
def Linear(A, x, b):
    """
    Args:
        A: ndarray of shape (M x N)
        x: ndarray of shape (N x 1)
        b: ndarray of shape (M x 1)

    Returns:
        Linear output ndarray of shape (M)
    """
    assert len(A.shape) == 2, f"A must be of dim 2, not {len(A.shape)}"
    Am, An = A.shape

    assert x.shape == (An, 1), f"X must be shape ({An}, 1) to do dot"
    Ax = np.dot(A, x)  # Shape (M x 1)

    assert b.shape == (Am, 1), f"Bias term must be shape ({Am}, 1)"
    result = np.add(Ax, b)  # (M x 1) + (M x 1)

    ravel_result = np.ravel(result)
    assert ravel_result.shape == (Am,), f"Uh oh, ravel result is shape {ravel_result.shape} and not {(Am,)}"
    return ravel_result
```

At every step of this function we do an `assert` to make sure all the `ndarray` shapes are what we expect.
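As a quick sanity check of the "fail early" behavior, here is a condensed, self-contained copy of the assert-guarded `Linear` (docstring and final ravel check trimmed), called once with the buggy `(4,)` bias and once with a `(4, 1)` bias. The random inputs are placeholders.

```python
import numpy as np


def Linear(A, x, b):
    # Same shape checks as the assert-guarded version, condensed.
    assert len(A.shape) == 2, f"A must be of dim 2, not {len(A.shape)}"
    Am, An = A.shape
    assert x.shape == (An, 1), f"X must be shape ({An}, 1) to do dot"
    Ax = np.dot(A, x)  # Shape (M x 1)
    assert b.shape == (Am, 1), f"Bias term must be shape ({Am}, 1)"
    return np.ravel(np.add(Ax, b))  # Shape (M)


A = np.random.random(size=(4, 4))
x = np.random.random(size=(4, 1))

# The buggy (4,) bias now fails at the bias check, before any broadcasting happens.
try:
    Linear(A, x, np.random.random(size=(4,)))
    msg = None
except AssertionError as e:
    msg = str(e)
print(msg)  # Bias term must be shape (4, 1)

# A (4, 1) bias passes every check.
result = Linear(A, x, np.random.random(size=(4, 1)))
print(result.shape)  # (4,)
```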

As a result, `Linear` is a bit "safer". But compared to what we had originally, this approach is much less readable. We also inherit some of the baggage that comes with runtime error checking, such as:

- **Incomplete checking**: Have we checked all expected shape failure modes?
- **Slow debugging cycles**: How many refactor->run cycles will we have to do to pass the checks?
- **Additional testing**: Do we have to update our tests to cover our runtime error checks?

Overall, runtime error checking is not a bad thing. In most cases it's very necessary! But when it comes to shape errors, we can leverage an additional approach: static type checking.

Even though `python` is a dynamically typed language, `python==3.5` introduced the `typing` module so that static type checkers can validate type-hinted `python` code. (See this video for more details.)

Over time, many third-party libraries (like `numpy`) have started to type hint their codebases, which we can use to our benefit.

In order to help us prevent shape errors, let's see what typing capabilities exist in `numpy`.

## `dtype` typing `numpy` arrays

As of writing this post, `numpy==1.24.2` only supports typing an `ndarray`'s `dtype` (`uint8`, `float64`, etc.).

Using `numpy`'s existing type hinting tooling, here's how we would add `dtype` type information to our `Linear` example (note: there is an intentional type error):

```python
from typing import TypeVar

import numpy as np
from numpy.typing import NDArray

GenericType = TypeVar("GenericType", bound=np.generic)


def Linear(
    A: NDArray[GenericType],
    x: NDArray[GenericType],
    b: NDArray[GenericType],
) -> NDArray[GenericType]:
    """
    Args:
        A: ndarray of shape (M x N)
        x: ndarray of shape (N x 1)
        b: ndarray of shape (M x 1)

    Returns:
        Linear output ndarray of shape (M)
    """
    assert len(A.shape) == 2, f"A must be of dim 2, not {len(A.shape)}"
    Am, An = A.shape

    assert x.shape == (An, 1), f"X must be shape ({An}, 1) to do dot"
    Ax: NDArray[GenericType] = np.dot(A, x)  # Shape (M x 1)

    assert b.shape == (Am, 1), f"Bias term must be shape ({Am}, 1)"
    result: NDArray[GenericType] = np.add(Ax, b)  # (M x 1) + (M x 1)

    ravel_result: NDArray[GenericType] = np.ravel(result)
    assert ravel_result.shape == (Am,), f"Uh oh, ravel result is shape {ravel_result.shape} and not {(Am,)}"
    return ravel_result


A: NDArray[np.float64] = np.random.standard_normal(size=(10, 10))
x: NDArray[np.float64] = np.random.standard_normal(size=(10, 1))
b: NDArray[np.float32] = np.random.standard_normal(size=(10, 1))
y: NDArray[np.float64] = Linear(A, x, b)
print(y)
print(y.dtype)
```

```
[-1.81553298 -4.94471634  3.24041295  3.34200411  2.221593    7.59161372
  3.1321597  -0.37862935 -1.98975116  1.57701057]
float64
```

Even though this code is "runnable" and doesn't produce an error, a type checker like `pyright` tells us a different story.

```shell
pyright linear_bad_typing.py
```

```
No configuration file found.
No pyproject.toml file found.
stubPath /mnt/typings is not a valid directory.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/linear_bad_typing.py
  /mnt/linear_bad_typing.py:40:26 - error: Expression of type "ndarray[Any, dtype[float64]]" cannot be assigned to declared type "NDArray[float32]"
    "ndarray[Any, dtype[float64]]" is incompatible with "NDArray[float32]"
      TypeVar "_DType_co@ndarray" is covariant
        "dtype[float64]" is incompatible with "dtype[float32]"
          TypeVar "_DTypeScalar_co@dtype" is covariant
            "float64" is incompatible with "float32" (reportGeneralTypeIssues)
  /mnt/linear_bad_typing.py:41:39 - error: Argument of type "NDArray[float32]" cannot be assigned to parameter "b" of type "NDArray[GenericType@Linear]" in function "Linear"
    "NDArray[float32]" is incompatible with "NDArray[float64]"
      TypeVar "_DType_co@ndarray" is covariant
        "dtype[float32]" is incompatible with "dtype[float64]"
          TypeVar "_DTypeScalar_co@dtype" is covariant
            "float32" is incompatible with "float64" (reportGeneralTypeIssues)
2 errors, 0 warnings, 0 informations
Completed in 0.606sec
```

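`pyright` has noticed that when we create our `b` variable, we gave it a `dtype` type hint that is incompatible with what `np.random.standard_normal` returns. It also flags passing `b` into `Linear`, since `A` and `x` have already fixed `GenericType` to `float64`.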

Now we know to adjust the type hint of `b` to match the `dtype` that `np.random.standard_normal` returns (`NDArray[np.float64]`).
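If `float32` really were what we wanted for `b`, one option (a sketch, not the only fix) is `numpy`'s newer `Generator` API, whose `standard_normal` method accepts a `dtype` argument of `np.float64` (the default) or `np.float32`. Whether a type checker accepts the `float32` annotation below depends on the stubs shipped with your `numpy` version, so treat that part as an assumption; the runtime dtypes are what the snippet checks. The seed is arbitrary.

```python
import numpy as np
from numpy.typing import NDArray

rng = np.random.default_rng(seed=0)

# Legacy module-level API: always float64, so annotate it that way.
b64: NDArray[np.float64] = np.random.standard_normal(size=(10, 1))

# Generator API: the dtype can be requested explicitly (float64 or float32).
b32: NDArray[np.float32] = rng.standard_normal(size=(10, 1), dtype=np.float32)

print(b64.dtype, b32.dtype)  # float64 float32
```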

## Shape typing `numpy` arrays

While `dtype` typing is great, it's not much help in preventing shape errors (like the one in our original example).

Ideally, in addition to a `dtype` type, we could also include information about an `ndarray`'s shape, which is what shape typing is about.

Shape typing is a technique used to annotate information about the dimensionality and size of an array. In the context of `numpy` and the `python` type hinting system, we can use shape typing to catch shape errors before runtime.

For more information about shape typing, check out this Google doc on a shape typing syntax proposal by Matthew Rahtz, Jörg Bornschein, Vlad Mikulik, Tim Harley, Matthew Willson, Dimitrios Vytiniotis, Sergei Lebedev, and Adam Paszke.

As we've seen, `numpy`'s `NDArray` currently only supports `dtype` typing and doesn't have any of this kind of shape typing ability. But why is that? If we dig into the definition of the `NDArray` type:

```python
ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)

if TYPE_CHECKING or sys.version_info >= (3, 9):
    _DType = np.dtype[ScalarType]
    NDArray = np.ndarray[Any, np.dtype[ScalarType]]
else:
    _DType = _GenericAlias(np.dtype, (ScalarType,))
    NDArray = _GenericAlias(np.ndarray, (Any, _DType))
```

And follow the definition of `np.ndarray` ...

```python
class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
```

We can see that `np.ndarray` already takes a shape type parameter (`_ShapeType`)! But unfortunately, if we look at its definition ...

```python
# TODO: Set the `bound` to something more suitable once we
# have proper shape support
_ShapeType = TypeVar("_ShapeType", bound=Any)
_ShapeType2 = TypeVar("_ShapeType2", bound=Any)
```

😭 Looks like we're stuck with `Any`, which doesn't add any useful shape information to our types.

Luckily for us, we don't have to wait for shape support in `numpy`. PEP 646 (variadic generics) lays the foundation for shape typing and has already been accepted into `python==3.11`! And it's supported by `pyright`! In theory, these two things give us most of the ingredients to do basic shape typing.

Now this blog post isn't about the details of PEP 646 or variadic generics. Understanding PEP 646 will help, but it's not needed to understand the rest of this post.

In order to add rudimentary shape typing to `numpy`, we can simply change the `Any` type in the `NDArray` type definition to an unpacked variadic generic like so:

```python
ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)
Shape = TypeVarTuple("Shape")

if TYPE_CHECKING or sys.version_info >= (3, 9):
    _DType = np.dtype[ScalarType]
    NDArray = np.ndarray[*Shape, np.dtype[ScalarType]]
else:
    _DType = _GenericAlias(np.dtype, (ScalarType,))
    NDArray = _GenericAlias(np.ndarray, (Any, _DType))
```

Doing so allows us to fill in a `Tuple`-based type (indicating shape) in an `NDArray` alongside a `dtype` type. And shape typing with `Tuple`s enables us to define function overloads that describe to a type checker the possible ways a function can change the shape of an `NDArray`.

Let's look at an example of using these concepts to type a wrapper function for `np.random.standard_normal` from our `Linear` example, with an intentional type error:

```python
import numpy as np
from numpy.typing import NDArray
from typing import Tuple, TypeVar, Literal

# Generic dimension sizes types
T1 = TypeVar("T1", bound=int)
T2 = TypeVar("T2", bound=int)
T3 = TypeVar("T3", bound=int)

# Dimension types represented as tuples
Shape = Tuple
Shape1D = Shape[T1]
Shape2D = Shape[T1, T2]
Shape3D = Shape[T1, T2, T3]
ShapeND = Shape[T1, ...]
ShapeNDType = TypeVar("ShapeNDType", bound=ShapeND)

def rand_normal_matrix(shape: ShapeNDType) -> NDArray[ShapeNDType, np.float64]:
    """Return a random ND normal matrix."""
    return np.random.standard_normal(size=shape)

# Yay correctly typed 2x2x2 cube!
LENGTH = Literal[2]
cube: NDArray[Shape3D[LENGTH, LENGTH, LENGTH], np.float64] = rand_normal_matrix((2,2,2))
print(cube)

SIDE = Literal[4]

# Uh oh the shapes won't match!
square: NDArray[Shape2D[SIDE, SIDE], np.float64] = rand_normal_matrix((3,3))
print(square)
```

Notice there are no `assert` statements here. And instead of scattering comments about shape, we indicate shape in the type hints.

Now while this code is "runnable", `pyright` will tell us something else:

```shell
py -m pyright bad_shape_typing.py --lib
```

```
No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/bad_shape_typing.py
  /mnt/bad_shape_typing.py:30:71 - error: Argument of type "tuple[Literal[3], Literal[3]]" cannot be assigned to parameter "shape" of type "ShapeNDType@rand_normal_matrix" in function "rand_normal_matrix"
    Type "Shape2D[SIDE, SIDE]" cannot be assigned to type "tuple[Literal[3], Literal[3]]" (reportGeneralTypeIssues)
1 error, 0 warnings, 0 informations
Completed in 0.535sec
```

`pyright` is telling us we've incorrectly typed `square`: its declared `4x4` shape is incompatible with the `3x3` shape we passed in. Now we know we need to go back and fix the type to what a type checker should expect.

Huzzah shape typing!!

## Moar `numpy` shape typing!

Now that we have shape typed one function, let's step it up a notch. Let's try typing each `numpy` function in our `Linear` example to include shape types. We've already typed `np.random.standard_normal`, so next let's do `np.dot`.

If we look at the docs for `np.dot`, we see there are 5 cases it supports:

- Both arguments are `1D` arrays (an inner product)
- Both arguments are `2D` arrays (resulting in a `matmul`)
- Either argument is a scalar
- The first argument is an `ND` array and the second is a `1D` array
- The first argument is an `ND` array and the second is an `MD` array
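To make the cases concrete, here is a quick runtime check of the result shape `np.dot` produces in each one (arrays of ones are placeholders). The last case follows the rule in the `np.dot` docs: a sum product over the last axis of the first argument and the second-to-last axis of the second.

```python
import numpy as np

v = np.ones(3)          # 1D, shape (3,)
M = np.ones((2, 3))     # 2D
N = np.ones((3, 4))     # 2D
T = np.ones((5, 2, 3))  # ND (here 3D)
U = np.ones((4, 3, 6))  # MD (here 3D)

print(np.shape(np.dot(v, v)))  # () -> 1D . 1D is a scalar (inner product)
print(np.dot(M, N).shape)      # (2, 4) -> 2D . 2D is a matmul
print(np.dot(2.0, M).shape)    # (2, 3) -> a scalar argument just multiplies
print(np.dot(T, v).shape)      # (5, 2) -> ND . 1D sums over T's last axis
print(np.dot(T, U).shape)      # (5, 2, 4, 6) == T.shape[:-1] + U.shape[:-2] + U.shape[-1:]
```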

We can implement most of these cases as overloads of a thin `dot` wrapper around `np.dot`:

```python
ShapeVarGen = TypeVarTuple("ShapeVarGen")

# 1D . 1D: inner product, returns a scalar
@overload
def dot(x1: NDArray[Shape1D[T1], GenericDType], x2: NDArray[Shape1D[T1], GenericDType], /) -> GenericDType:
    ...


# ND . 1D: sum product over the last axis of x1 and the only axis of x2
@overload
def dot(
    x1: NDArray[Shape[*ShapeVarGen, T1], GenericDType], x2: NDArray[Shape1D[T1], GenericDType], /
) -> NDArray[Shape[*ShapeVarGen], GenericDType]:
    ...


# 2D . 2D: matrix multiplication
@overload
def dot(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: NDArray[Shape2D[T2, T3], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T3], GenericDType]:
    ...


# scalar . scalar: plain multiplication
@overload
def dot(x1: GenericDType, x2: GenericDType, /) -> GenericDType:
    ...


def dot(x1, x2):
    return np.dot(x1, x2)
```

The only case we can't implement is an `ND` dimensional array with an `MD` dimensional array. Ideally we would try implementing it like so:

```python
ShapeVarGen1 = TypeVarTuple("ShapeVarGen1")
ShapeVarGen2 = TypeVarTuple("ShapeVarGen2")

# ND . MD: sum product over the last axis of x1 and the second-to-last axis of x2
@overload
def dot(
    x1: NDArray[Shape[*ShapeVarGen1, T1], GenericDType], x2: NDArray[Shape[*ShapeVarGen2, T1, T2], GenericDType], /
) -> NDArray[Shape[*ShapeVarGen1, *ShapeVarGen2, T2], GenericDType]:
    ...
```

But PEP 646 currently doesn't allow more than one unpacked type variable tuple in the same type (as in the return type above). If you know of another way to cover this case, let me know! Luckily, our `Linear` use case only uses scalars, vectors, and matrices, which are covered by our four overloads.

Here's how we would use these `dot` overloads to do the dot product between a `2x3` matrix and a `3x2` matrix with type hints:

```python
import numpy as np
from numpy.typing import NDArray
from numpy_shape_typing.dot import dot
from numpy_shape_typing.types import ShapeNDType, Shape2D
from numpy_shape_typing.rand import rand_normal_matrix

from typing import Literal

ROWS = Literal[2]
COLS = Literal[3]
A: NDArray[Shape2D[ROWS, COLS], np.float64] = rand_normal_matrix((2,3))
B: NDArray[Shape2D[COLS, ROWS], np.float64] = rand_normal_matrix((3,2))
C: NDArray[Shape2D[ROWS, ROWS], np.float64] = dot(A, B)
print(C)
```

And if we check with `pyright`:

```shell
py -m pyright good_dot.py --lib
```

```
No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
0 errors, 0 warnings, 0 informations
Completed in 0.909sec
```

Everything looks good as it should!

And if we change the types to invalid matrix shapes:

```python
import numpy as np
from numpy.typing import NDArray
from numpy_shape_typing.dot import dot
from numpy_shape_typing.rand import rand_normal_matrix
from numpy_shape_typing.types import ShapeNDType, Shape2D

from typing import Literal

ROWS = Literal[2]
COLS = Literal[3]
SLICES = Literal[4]

# uh oh based on these types we can't do a valid dot product!
A: NDArray[Shape2D[ROWS, COLS], np.float64] = rand_normal_matrix((2,3))
B: NDArray[Shape2D[SLICES, COLS], np.float64] = rand_normal_matrix((4,3))
C: NDArray[Shape2D[ROWS, COLS], np.float64] = dot(A, B)
print(C)
```

And if we check with `pyright`:

```shell
py -m pyright ./bad_dot.py --lib
```

```
No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/bad_dot.py
  /mnt/bad_dot.py:16:54 - error: Argument of type "NDArray[Shape2D[SLICES, COLS], float64]" cannot be assigned to parameter "x2" of type "GenericDType@dot" in function "dot"
    Type "NDArray[Shape2D[ROWS, COLS], float64]" cannot be assigned to type "NDArray[Shape2D[SLICES, COLS], float64]" (reportGeneralTypeIssues)
1 error, 0 warnings, 0 informations
Completed in 0.908sec
```

`pyright` lets us know that the shapes in our types are incorrect based on the `np.dot` type overloads we've specified.

## Even moar `numpy` shape typing!

The next function we are going to type is `np.add`. The `numpy` docs only show two cases:

- Two `ND` array arguments of the same shape are added element-wise
- Two `ND` array arguments that are not the same shape must be broadcastable to a common shape

Covering the first case is easy, but the second case is much harder, as we would have to come up with a scheme to cover `numpy`'s array broadcasting system. Currently `python==3.11`'s `typing` doesn't have a generic way to cover all the broadcasting rules. (If you know of a way, let me know!)
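For reference, the rule itself is short to state at runtime. Below is a sketch of a hypothetical helper, `broadcast_shape`, that computes the broadcast result shape by hand (right-align the shapes, then each pair of sizes must be equal or contain a `1`) and checks itself against `np.broadcast_shapes`. The left padding and the per-dimension "equal or 1" choice are the parts the type system can't express generically today, which is why the overloads below enumerate two-dimensional cases by hand.

```python
import numpy as np


def broadcast_shape(s1: tuple[int, ...], s2: tuple[int, ...]) -> tuple[int, ...]:
    """Hypothetical helper: the shape numpy broadcasting would produce."""
    ndim = max(len(s1), len(s2))
    # Right-align by left-padding the shorter shape with 1s.
    p1 = (1,) * (ndim - len(s1)) + s1
    p2 = (1,) * (ndim - len(s2)) + s2
    out = []
    for d1, d2 in zip(p1, p2):
        if d1 == d2 or d2 == 1:
            out.append(d1)
        elif d1 == 1:
            out.append(d2)
        else:
            raise ValueError(f"shapes {s1} and {s2} are not broadcastable")
    return tuple(out)


cases = [((4, 1), (4,)), ((3, 4), (4,)), ((3, 4), (1, 4)), ((3, 4), (3, 1)), ((3, 1), (1, 4))]
for s1, s2 in cases:
    assert broadcast_shape(s1, s2) == np.broadcast_shapes(s1, s2)
    print(s1, "+", s2, "->", broadcast_shape(s1, s2))
```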

However, if we scope the second case down to two dimensions, we can cover most of the broadcasting rules with a handful of overloads (cases where both arguments get stretched, such as `(T1, 1)` plus `(1, T2)`, are not covered):

```python
from typing import overload

import numpy as np
from numpy.typing import NDArray

from numpy_shape_typing.types import ONE, T1, T2, GenericDType, Shape1D, Shape2D, ShapeVarGen


@overload
def add(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: NDArray[Shape1D[T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape1D[T2], GenericDType],
    x2: NDArray[Shape2D[T1, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: NDArray[Shape1D[ONE], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape1D[ONE], GenericDType],
    x2: NDArray[Shape2D[T1, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: NDArray[Shape2D[T1, ONE], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: NDArray[Shape2D[ONE, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[T1, ONE], GenericDType],
    x2: NDArray[Shape2D[T1, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[ONE, T2], GenericDType],
    x2: NDArray[Shape2D[T1, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: GenericDType,
    x2: NDArray[Shape2D[T1, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: GenericDType,
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[*ShapeVarGen, GenericDType],
    x2: NDArray[*ShapeVarGen, GenericDType],
    /,
) -> NDArray[*ShapeVarGen, GenericDType]:
    ...


def add(x1, x2):
    return np.add(x1, x2)
```

Using these overloads, here is how we would catch unexpected array broadcasts (similar to the one from our original `Linear` example).

```python
from typing import Literal

import numpy as np
from numpy.typing import NDArray

from numpy_shape_typing.add import add
from numpy_shape_typing.dot import dot
from numpy_shape_typing.rand import rand_normal_matrix
from numpy_shape_typing.types import ONE, Shape1D, Shape2D

COLS = Literal[4]
A: NDArray[Shape2D[COLS, COLS], np.float64] = rand_normal_matrix((4, 4))
B: NDArray[Shape2D[ONE, COLS], np.float64] = rand_normal_matrix((1, 4))
C: NDArray[Shape2D[ONE, COLS], np.float64] = add(A, B)
print(C)
```

In the example above, the actual output is a `4x4` matrix, but what our types ask for is an output shape of `1x4`. Let's see what `pyright` says:

```shell
py -m pyright unnexpected_broadcast.py --lib
```

```
No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/unnexpected_broadcast.py
  /mnt/unnexpected_broadcast.py:14:50 - error: Argument of type "NDArray[Shape2D[COLS, COLS], float64]" cannot be assigned to parameter "x1" of type "NDArray[*ShapeVarGen@add, GenericDType@add]" in function "add"
    "NDArray[Shape2D[COLS, COLS], float64]" is incompatible with "NDArray[Shape2D[ONE, COLS], float64]"
      TypeVar "_ShapeType@ndarray" is invariant
        "*tuple[Shape2D[COLS, COLS]]" is incompatible with "*tuple[Shape2D[ONE, COLS]]"
          Tuple entry 1 is incorrect type
            "Shape2D[COLS, COLS]" is incompatible with "Shape2D[ONE, COLS]" (reportGeneralTypeIssues)
1 error, 0 warnings, 0 informations
Completed in 2.757sec
```

`pyright` informs us that our shapes are off and that we got broadcast to a `4x4`! Huzzah shape typing!

## Hitting the limitations of shape typing 😿

The last function we will type to finish off our `Linear` example is `np.ravel`. However, this is where we start hitting some major limitations of shape typing as it exists today in `python` and `numpy`.

From the `numpy` docs on `np.ravel`, the only case we need to cover is that any `ND` array gets collapsed into a `1D` array whose size is the total number of elements. Luckily, that final `1D` size is just the product of all the input dimension sizes.
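At runtime this relationship is easy to check. Here is a small sketch using `math.prod` (Python 3.8+) on an arbitrary `2x3x4` array:

```python
import math

import numpy as np

arr = np.zeros((2, 3, 4))
flat = np.ravel(arr)

# ravel always yields a 1D array whose length is the product of the input dims.
print(flat.shape)            # (24,)
print(math.prod(arr.shape))  # 24
```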

Ideally we would try to write code that looks something like this:

```python
ShapeVarGen = TypeVarTuple("ShapeVarGen")

@overload
def ravel(
    arr: NDArray[Shape[*ShapeVarGen], GenericDType]
) -> NDArray[Shape1D[Product[*ShapeVarGen]], GenericDType]:
    ...
```

But unfortunately, `python`'s `typing` module currently has no notion of a `Product` type, or any other way to do algebraic typing.

However, for the sake of completeness, we can fake it!

If we scope `np.ravel` down from generic `ND` typing to at most two dimensions, and cap the output dimension size at some maximum number, we can write an overload for every pair of factors that multiplies to each output size. We would effectively be typing a multiplication table 😆, but it works and gets us to a "partially" typed `np.ravel`.

Here's how we can do it.

First we create a bunch of `Literal` types (our factors):

```python
ZERO = Literal[0]
ONE = Literal[1]
TWO = Literal[2]
THREE = Literal[3]
FOUR = Literal[4]
...
```

Then we define "multiply" types for factor pairs of numbers:

```python
SHAPE_2D_MUL_TO_ONE = TypeVar(
    "SHAPE_2D_MUL_TO_ONE",
    bound=Shape2D[Literal[ONE], Literal[ONE]],
)
SHAPE_2D_MUL_TO_TWO = TypeVar(
    "SHAPE_2D_MUL_TO_TWO",
    bound=Union[Shape2D[Literal[ONE], Literal[TWO]], Shape2D[Literal[TWO], Literal[ONE]]],
)
SHAPE_2D_MUL_TO_THREE = TypeVar(
    "SHAPE_2D_MUL_TO_THREE",
    bound=Union[Shape2D[Literal[ONE], Literal[THREE]], Shape2D[Literal[THREE], Literal[ONE]]],
)
SHAPE_2D_MUL_TO_FOUR = TypeVar(
    "SHAPE_2D_MUL_TO_FOUR",
    bound=Union[
        Shape2D[Literal[ONE], Literal[FOUR]],
        Shape2D[Literal[TWO], Literal[TWO]],
        Shape2D[Literal[FOUR], Literal[ONE]],
    ],
)
```

Then lastly we wire these types up into individual `ravel` overloads (and cover a few generic ones while we're at it):

```python
@overload
def ravel(arr: NDArray[SHAPE_2D_MUL_TO_ONE, GenericDType]) -> NDArray[Shape1D[ONE], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[SHAPE_2D_MUL_TO_TWO, GenericDType]) -> NDArray[Shape1D[TWO], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[SHAPE_2D_MUL_TO_THREE, GenericDType]) -> NDArray[Shape1D[THREE], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[SHAPE_2D_MUL_TO_FOUR, GenericDType]) -> NDArray[Shape1D[FOUR], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[Shape2D[T1, ONE], GenericDType]) -> NDArray[Shape1D[T1], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[Shape2D[ONE, T1], GenericDType]) -> NDArray[Shape1D[T1], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[Shape1D[T1], GenericDType]) -> NDArray[Shape1D[T1], GenericDType]:
    ...
```

Now we can rinse and repeat for as many numbers as we like!
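Writing these out by hand gets tedious quickly, so one option is to generate them. Here is a rough sketch of a hypothetical generator script: the function names are made up, and it names the TypeVars with digits (`SHAPE_2D_MUL_TO_4`) and uses plain integer literals instead of the `ONE`/`TWO` aliases, which should be equivalent since nested `Literal`s flatten. It emits the `TypeVar` bound and the matching `ravel` overload for a given output size as source text.

```python
def factor_pairs(n: int) -> list[tuple[int, int]]:
    """All ordered pairs (a, b) of positive ints with a * b == n."""
    return [(a, n // a) for a in range(1, n + 1) if n % a == 0]


def ravel_overload_source(n: int) -> str:
    """Emit source for a SHAPE_2D_MUL_TO_<n> TypeVar and its ravel overload."""
    shapes = [f"Shape2D[Literal[{a}], Literal[{b}]]" for a, b in factor_pairs(n)]
    # A single factor pair needs no Union wrapper.
    bound = shapes[0] if len(shapes) == 1 else f"Union[{', '.join(shapes)}]"
    name = f"SHAPE_2D_MUL_TO_{n}"
    return (
        f'{name} = TypeVar("{name}", bound={bound})\n\n\n'
        "@overload\n"
        f"def ravel(arr: NDArray[{name}, GenericDType]) -> NDArray[Shape1D[Literal[{n}]], GenericDType]:\n"
        "    ...\n"
    )


# Print definitions for output sizes 1 through 8, ready to paste into a module.
for n in range(1, 9):
    print(ravel_overload_source(n))
```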

Here is how we'd use this typing to catch a shape type error with `ravel`:

```python
import numpy as np
from numpy.typing import NDArray

from numpy_shape_typing.rand import rand_normal_matrix
from numpy_shape_typing.ravel import ravel
from numpy_shape_typing.types import FOUR, SEVEN, TWO, Shape1D, Shape2D

A: NDArray[Shape2D[TWO, FOUR], np.float64] = rand_normal_matrix((2, 4))
B: NDArray[Shape1D[SEVEN], np.float64] = ravel(A)
print(B)
```

```shell
py -m pyright raveling.py --lib
```

```
No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/raveling.py
  /mnt/raveling.py:9:42 - error: Expression of type "NDArray[Shape1D[EIGHT], float64]" cannot be assigned to declared type "NDArray[Shape1D[SEVEN], float64]"
    "NDArray[Shape1D[EIGHT], float64]" is incompatible with "NDArray[Shape1D[SEVEN], float64]"
      TypeVar "_ShapeType@ndarray" is invariant
        "*tuple[Shape1D[EIGHT]]" is incompatible with "*tuple[Shape1D[SEVEN]]"
          Tuple entry 1 is incorrect type
            "Shape1D[EIGHT]" is incompatible with "Shape1D[SEVEN]" (reportGeneralTypeIssues)
1 error, 0 warnings, 0 informations
Completed in 0.933sec
```

## Putting it all together

So far we've gone through typing a small subset of `numpy`'s functions (`np.random.standard_normal`, `np.dot`, `np.add`, and `np.ravel` in all).

Now we can chain these typed functions together to form a typed `Linear` implementation like so:

```python
from typing import Literal

import numpy as np
from numpy.typing import NDArray

from numpy_shape_typing.add import add
from numpy_shape_typing.dot import dot
from numpy_shape_typing.rand import rand_normal_matrix
from numpy_shape_typing.ravel import ravel
from numpy_shape_typing.types import ONE, T1, T2, GenericDType, Shape1D, Shape2D


def Linear(
    A: NDArray[Shape2D[T1, T2], GenericDType],
    x: NDArray[Shape2D[T2, ONE], GenericDType],
    b: NDArray[Shape2D[T1, ONE], GenericDType],
) -> NDArray[Shape1D[T1], GenericDType]:
    Ax = dot(A, x)
    Axb = add(Ax, b)
    return ravel(Axb)


IN_DIM = Literal[3]
in_dim: IN_DIM = 3

OUT_DIM = Literal[4]
out_dim: OUT_DIM = 4

# bad type >:(
BAD_OUT_DIM = Literal[5]

A: NDArray[Shape2D[OUT_DIM, IN_DIM], np.float64] = rand_normal_matrix((out_dim, in_dim))
x: NDArray[Shape2D[IN_DIM, ONE], np.float64] = rand_normal_matrix((in_dim, 1))
b: NDArray[Shape2D[OUT_DIM, ONE], np.float64] = rand_normal_matrix((out_dim, 1))

# this is a bad type!
y: NDArray[Shape1D[BAD_OUT_DIM], np.float64] = Linear(A, x, b)
```

I've included an intentional type error which should be caught by `pyright` like so:

```shell
py -m pyright linear_type_bad.py --lib
```

```
No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/linear_type_bad.py
  /mnt/linear_type_bad.py:37:55 - error: Argument of type "NDArray[Shape2D[OUT_DIM, IN_DIM], float64]" cannot be assigned to parameter "A" of type "NDArray[Shape2D[T1@Linear, T2@Linear], GenericDType@Linear]" in function "Linear"
    "NDArray[Shape2D[OUT_DIM, IN_DIM], float64]" is incompatible with "NDArray[Shape2D[BAD_OUT_DIM, IN_DIM], float64]"
      TypeVar "_ShapeType@ndarray" is invariant
        "*tuple[Shape2D[OUT_DIM, IN_DIM]]" is incompatible with "*tuple[Shape2D[BAD_OUT_DIM, IN_DIM]]"
          Tuple entry 1 is incorrect type
            "Shape2D[OUT_DIM, IN_DIM]" is incompatible with "Shape2D[BAD_OUT_DIM, IN_DIM]" (reportGeneralTypeIssues)
  /mnt/linear_type_bad.py:37:61 - error: Argument of type "NDArray[Shape2D[OUT_DIM, ONE], float64]" cannot be assigned to parameter "b" of type "NDArray[Shape2D[T1@Linear, ONE], GenericDType@Linear]" in function "Linear"
    "NDArray[Shape2D[OUT_DIM, ONE], float64]" is incompatible with "NDArray[Shape2D[BAD_OUT_DIM, ONE], float64]"
      TypeVar "_ShapeType@ndarray" is invariant
        "*tuple[Shape2D[OUT_DIM, ONE]]" is incompatible with "*tuple[Shape2D[BAD_OUT_DIM, ONE]]"
          Tuple entry 1 is incorrect type
            "Shape2D[OUT_DIM, ONE]" is incompatible with "Shape2D[BAD_OUT_DIM, ONE]" (reportGeneralTypeIssues)
2 errors, 0 warnings, 0 informations
Completed in 8.155sec
```

And huzzah again! `pyright` has caught the shape type error!

And now we can fix this shape error by changing `BAD_OUT_DIM` to the correct output dimension size.

```python
from typing import Literal

import numpy as np
from numpy.typing import NDArray

from numpy_shape_typing.add import add
from numpy_shape_typing.dot import dot
from numpy_shape_typing.rand import rand_normal_matrix
from numpy_shape_typing.ravel import ravel
from numpy_shape_typing.types import ONE, T1, T2, GenericDType, Shape1D, Shape2D


def Linear(
    A: NDArray[Shape2D[T1, T2], GenericDType],
    x: NDArray[Shape2D[T2, ONE], GenericDType],
    b: NDArray[Shape2D[T1, ONE], GenericDType],
) -> NDArray[Shape1D[T1], GenericDType]:
    """
    Args:
        A: ndarray of shape (M x N)
        x: ndarray of shape (N x 1)
        b: ndarray of shape (M x 1)

    Returns:
        Linear output ndarray of shape (M)
    """
    Ax = dot(A, x)
    Axb = add(Ax, b)
    return ravel(Axb)


IN_DIM = Literal[3]
in_dim: IN_DIM = 3

OUT_DIM = Literal[4]
out_dim: OUT_DIM = 4

A: NDArray[Shape2D[OUT_DIM, IN_DIM], np.float64] = rand_normal_matrix((out_dim, in_dim))
x: NDArray[Shape2D[IN_DIM, ONE], np.float64] = rand_normal_matrix((in_dim, 1))
b: NDArray[Shape2D[OUT_DIM, ONE], np.float64] = rand_normal_matrix((out_dim, 1))
y: NDArray[Shape1D[OUT_DIM], np.float64] = Linear(A, x, b)
```

And if we check with `pyright`:

```shell
py -m pyright linear_type_good.py --lib
```

```
No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
0 errors, 0 warnings, 0 informations
Completed in 8.116sec
```

`pyright` tells us that our types are consistent!

## What's next?

You tell me! Many open source scientific computing libraries have GitHub issues about shape typing, such as:

- `numpy`: https://github.com/numpy/numpy/issues/16544
- `jax`: https://github.com/google/jax/issues/12049
- `pytorch`: https://github.com/pytorch/pytorch/issues/33953

So it's well recognized as a desirable feature. Some of the major technical hurdles we still need to overcome are:

- PEP 646 in mypy
- Arithmetic between `Literal`s
- Multiple type variable tuples
- Type bounds for variadic generics

Once these hurdles are overcome, I don't see any blockers stopping projects like `numpy` from being fully shape typed.

This post and the accompanying repo are just a sample of what shape typing might become. With future PEPs and work on the `python` type hinting system, we'll hopefully make our code incrementally safer.

Thanks for reading! (っ◔◡◔)っ ♥