import * as React from 'react'
  /* @jsx mdx */
import { mdx } from '@mdx-js/react';
/* @jsxRuntime classic */

/* @jsx mdx */

import DefaultLayout from "/home/node/work/src/templates/post.template.tsx";
// Frontmatter exported alongside the page component; it is an empty object in this file.
export const _frontmatter = {};
// Props spread onto the layout element in MDXContent, before the caller-supplied props.
const layoutProps = {
  _frontmatter
};
// Alias for the imported post template; MDXContent renders it as the wrapping element.
const MDXLayout = DefaultLayout;
export default function MDXContent({
  components,
  ...props
}) {
  return <MDXLayout {...layoutProps} {...props} components={components} mdxType="MDXLayout">


    <p>{`When doing any sort of tensor/array computation in `}<inlineCode parentName="p">{`python`}</inlineCode>{` (via `}<inlineCode parentName="p">{`numpy`}</inlineCode>{`,
`}<inlineCode parentName="p">{`pytorch`}</inlineCode>{`, `}<inlineCode parentName="p">{`jax`}</inlineCode>{`, or others), you will more often than not encounter shape
errors like the one below:`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`import numpy as np

size1 = (2,3)
size2 = (4,3)

M1 = np.random.random(size=size1)
M2 = np.random.random(size=size2)

try:
    print(np.dot(M1,M2))
except Exception as e:
    print(e)
`}</code></pre>
    <pre><code parentName="pre" {...{
        "className": "language-text"
      }}>{`shapes (2,3) and (4,3) not aligned: 3 (dim 1) != 4 (dim 0)
`}</code></pre>
    <p>{`And most of the time, these kinds of errors boil down to something like
accidentally forgetting to do a reshape or transpose like so.`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`import numpy as np

size1 = (2,3)
size2 = (4,3)

M1 = np.random.random(size=size1)
M2 = np.random.random(size=size2).T

try:
    print(np.dot(M1,M2))
except Exception as e:
    print(e)
`}</code></pre>
    <pre><code parentName="pre" {...{
        "className": "language-text"
      }}>{`[[0.68812413 0.63491692 0.375332   1.22395427]
 [0.57381506 0.42578404 0.19132443 0.8889217 ]]
`}</code></pre>
    <p>{`And while this is a mild case, shape bugs like these become more frequent as
operations grow more complex and as more dimensions are involved.`}</p>
    <p>{`Here's a slightly more complex example of a `}<inlineCode parentName="p">{`Linear`}</inlineCode>{` implementation in `}<inlineCode parentName="p">{`numpy`}</inlineCode>{`
with a subtle shape bug.`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`def Linear(A, x, b):
    """
    Takes matrix A (m x n) times a vector x (n x 1) and
    adds a bias. The resulting ndarray is then ravelled
    into a vector of size (m).
    """
    Ax = np.dot(A, x)
    Axb = np.add(Ax, b)
    return np.ravel(Axb)

A = np.random.random(size=(4,4))
x = np.random.random(size=(4,1))
b = np.random.random(size=(4))

result = Linear(A, x, b)
print(result)
print(result.shape)
`}</code></pre>
    <pre><code parentName="pre" {...{
        "className": "language-text"
      }}>{`[1.18041914 1.87580329 0.93373901 1.48799234 1.4920404  2.18742455
 1.24536027 1.79961361 2.29649806 2.99188221 2.04981793 2.60407127
 1.31159899 2.00698314 1.06491886 1.6191722 ]
(16,)
`}</code></pre>
    <p>{`The docstring of `}<inlineCode parentName="p">{`Linear`}</inlineCode>{` clearly says the result should be size `}<inlineCode parentName="p">{`m`}</inlineCode>{` (or
`}<inlineCode parentName="p">{`4`}</inlineCode>{`). But why then did we end up with a vector of size `}<inlineCode parentName="p">{`16`}</inlineCode>{`? If we dig into
each function we will eventually find that our problem is in how `}<inlineCode parentName="p">{`numpy`}</inlineCode>{`
handles an `}<inlineCode parentName="p">{`ndarray`}</inlineCode>{` of a different shape.`}</p>
    <p>{`If we break down `}<inlineCode parentName="p">{`Linear`}</inlineCode>{`, after `}<inlineCode parentName="p">{`np.dot`}</inlineCode>{` we have an `}<inlineCode parentName="p">{`ndarray`}</inlineCode>{` of shape
`}<inlineCode parentName="p">{`(4,1)`}</inlineCode>{` of which we do `}<inlineCode parentName="p">{`np.add`}</inlineCode>{` with a vector of shape `}<inlineCode parentName="p">{`(4)`}</inlineCode>{`. And here lies
our bug. We might naturally think that `}<inlineCode parentName="p">{`np.add`}</inlineCode>{` will do this addition element
wise, but instead we fell into an `}<a parentName="p" {...{
        "href": "https://numpy.org/doc/stable/user/basics.broadcasting.html#broadcastable-arrays"
      }}>{`array broadcasting`}</a>{` trap. Array broadcasting
is the set of rules `}<inlineCode parentName="p">{`numpy`}</inlineCode>{` uses to determine how to do arithmetic on differently
shaped `}<inlineCode parentName="p">{`ndarrays`}</inlineCode>{`. So instead of doing our computation element wise, `}<inlineCode parentName="p">{`numpy`}</inlineCode>{`
interprets this as doing a broadcast operation of addition, resulting in a
`}<inlineCode parentName="p">{`(4,4)`}</inlineCode>{` matrix, which subsequently gets "raveled" into a size `}<inlineCode parentName="p">{`16`}</inlineCode>{` vector.`}</p>
    <p>{`Now to fix this is easy, we just need to initialize our `}<inlineCode parentName="p">{`b`}</inlineCode>{` variable to be of
shape `}<inlineCode parentName="p">{`(4,1)`}</inlineCode>{` so `}<inlineCode parentName="p">{`numpy`}</inlineCode>{` will interpret the `}<inlineCode parentName="p">{`np.add`}</inlineCode>{` as an element wise
addition.`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`def Linear(A, x, b):
    """
    Takes matrix A (m x n) times a vector x (n x 1) and
    adds a bias. The resulting ndarray is then ravelled
    into a vector of size (m).
    """
    Ax = np.dot(A, x)
    Axb = np.add(Ax, b)
    return np.ravel(Axb)

A = np.random.random(size=(4,4))
x = np.random.random(size=(4,1))
b = np.random.random(size=(4,1))

result = Linear(A, x, b)
print(result)
print(result.shape)
`}</code></pre>
    <pre><code parentName="pre" {...{
        "className": "language-text"
      }}>{`[1.15227694 1.24640271 0.63951685 1.13304944]
(4,)
`}</code></pre>
    <p>{`We've solved the problem, but how can we be smarter to prevent this error from
happening again?`}</p>
    <h2 {...{
      "id": "existing-ways-to-stop-shape-bugs"
    }}>{`Existing ways to stop shape bugs`}</h2>
    <p>{`The simplest way we can try to stop this shape bug is with good docs. Ideally
we should always have good docs, but we can make it a point to include what
the shape expectations are like so:`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`def Linear(A, x, b):
    """
    Args:
        A: ndarray of shape (M x N)
        x: ndarray of shape (N x 1)
        b: ndarray of shape (M x 1)

    Returns:
        Linear output ndarray of shape (M)
    """
    Ax = np.dot(A, x) # Shape (M x 1)
    Axb = np.add(Ax, b) # (M x 1) + (M x 1)
    return np.ravel(Axb) # Shape (M)
`}</code></pre>
    <p>{`Now while informative, nothing is preventing us from encountering the same bug
again. The only benefit this gives us is making the debugging process a
bit easier.`}</p>
    <p>{`We can do better.`}</p>
    <p>{`Another approach in addition to good docs that's more of a preventative action
is to use assertions. By sprinkling `}<inlineCode parentName="p">{`assert`}</inlineCode>{` throughout `}<inlineCode parentName="p">{`Linear`}</inlineCode>{` with an
informative error message, we can "fail early" and start debugging like so:`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`def Linear(A, x, b):
    """
    Args:
        A: ndarray of shape (M x N)
        x: ndarray of shape (N x 1)
        b: ndarray of shape (M x 1)

    Returns:
        Linear output ndarray of shape (M)
    """
    assert len(A.shape) == 2, f"A must be of dim 2, not {len(A.shape)}"
    Am, An = A.shape

    assert x.shape == (An, 1), f"X must be shape ({An}, 1) to do dot"
    Ax = np.dot(A, x) # Shape (M x 1)

    assert b.shape == (Am, 1), f"Bias term must be shape ({Am}, 1)"
    result = np.add(Ax, b) # (M x 1) + (M x 1)

    ravel_result = np.ravel(result)
    assert ravel_result.shape == (Am,), f"Uh oh, ravel result is shape {ravel_result.shape} and not {(Am,)}"
    return ravel_result
`}</code></pre>
    <p>{`At every step of this function we do an `}<inlineCode parentName="p">{`assert`}</inlineCode>{` to make sure all the
`}<inlineCode parentName="p">{`ndarray`}</inlineCode>{` shapes are what we expect.`}</p>
    <p>{`As a result `}<inlineCode parentName="p">{`Linear`}</inlineCode>{` is a bit "safer". But compared to what we had originally,
this approach is much less readable. We also inherit some of the baggage that
comes with runtime error checking like:`}</p>
    <ul>
      <li parentName="ul">
        <p parentName="li"><strong parentName="p">{`Incomplete checking`}</strong>{`: Have we checked all expected shape failure modes?`}</p>
      </li>
      <li parentName="ul">
        <p parentName="li"><strong parentName="p">{`Slow debugging cycles`}</strong>{`: How many refactor-`}{`>`}{`run cycles will we have to do
to pass the checks?`}</p>
      </li>
      <li parentName="ul">
        <p parentName="li"><strong parentName="p">{`Additional testing`}</strong>{`: Do we have to update our tests to cover our runtime error
checks?`}</p>
      </li>
    </ul>
    <p>{`Overall runtime error checking is not a bad thing. In most cases it's very
necessary! But when it comes to shape errors, we can leverage an additional
approach, static type checking.`}</p>
    <p>{`Even though `}<inlineCode parentName="p">{`python`}</inlineCode>{` is a dynamically typed language, in `}<inlineCode parentName="p">{`python>=3.5`}</inlineCode>{` the
`}<inlineCode parentName="p">{`typing`}</inlineCode>{` module was introduced to enable static type checkers to validate type
hinted `}<inlineCode parentName="p">{`python`}</inlineCode>{` code. (See `}<a parentName="p" {...{
        "href": "https://www.youtube.com/watch?v=2wDvzy6Hgxg"
      }}>{`this video`}</a>{` for more details)`}</p>
    <p>{`Over time many third party libraries (like `}<inlineCode parentName="p">{`numpy`}</inlineCode>{`) have started to type hint
their codebases which we can use to our benefit.`}</p>
    <p>{`In order to help us prevent shape errors, let's see what typing capabilities
exist in `}<inlineCode parentName="p">{`numpy`}</inlineCode>{`.`}</p>
    <h2 {...{
      "id": "dtype-typing-numpy-arrays"
    }}><inlineCode parentName="h2">{`dtype`}</inlineCode>{` typing `}<inlineCode parentName="h2">{`numpy`}</inlineCode>{` arrays`}</h2>
    <p>{`As of writing this post, `}<inlineCode parentName="p">{`numpy==v1.24.2`}</inlineCode>{` only supports typing on an
`}<inlineCode parentName="p">{`ndarray`}</inlineCode>{`'s `}<inlineCode parentName="p">{`dtype`}</inlineCode>{` (`}<inlineCode parentName="p">{`uint8`}</inlineCode>{`, `}<inlineCode parentName="p">{`float64`}</inlineCode>{`, etc.).`}</p>
    <p>{`Using `}<inlineCode parentName="p">{`numpy`}</inlineCode>{`'s existing type hinting tooling, here's how we would include
`}<inlineCode parentName="p">{`dtype`}</inlineCode>{` type information to our `}<inlineCode parentName="p">{`Linear`}</inlineCode>{` example (note: there is an
intentional type error)`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`from typing import TypeVar

import numpy as np
from numpy.typing import NDArray

GenericType = TypeVar("GenericType", bound=np.generic)


def Linear(
    A: NDArray[GenericType],
    x: NDArray[GenericType],
    b: NDArray[GenericType],
) -> NDArray[GenericType]:
    """
    Args:
        A: ndarray of shape (M x N)
        x: ndarray of shape (N x 1)
        b: ndarray of shape (M x 1)

    Returns:
        Linear output ndarray of shape (M)
    """
    assert len(A.shape) == 2, f"A must be of dim 2, not {len(A.shape)}"
    Am, An = A.shape

    assert x.shape == (An, 1), f"X must be shape ({An}, 1) to do dot"
    Ax: NDArray[GenericType] = np.dot(A, x)  # Shape (M x 1)

    assert b.shape == (Am, 1), f"Bias term must be shape ({Am}, 1)"
    result: NDArray[GenericType] = np.add(Ax, b)  # (M x 1) + (M x 1)

    ravel_result: NDArray[GenericType] = np.ravel(result)
    assert ravel_result.shape == (Am,), f"Uh oh, ravel result is shape {ravel_result.shape} and not {(Am,)}"
    return ravel_result


A: NDArray[np.float64] = np.random.standard_normal(size=(10, 10))
x: NDArray[np.float64] = np.random.standard_normal(size=(10, 1))
b: NDArray[np.float32] = np.random.standard_normal(size=(10, 1))
y: NDArray[np.float64] = Linear(A, x, b)
print(y)
print(y.dtype)
`}</code></pre>
    <pre><code parentName="pre" {...{
        "className": "language-text"
      }}>{`[-1.81553298 -4.94471634  3.24041295  3.34200411  2.221593    7.59161372
  3.1321597  -0.37862935 -1.98975116  1.57701057]
float64
`}</code></pre>
    <p>{`Even though this code is "runnable" and doesn't produce an error, a type
checker like `}<inlineCode parentName="p">{`pyright`}</inlineCode>{` tells us a different story.`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-bash"
      }}>{`pyright linear_bad_typing.py
`}</code></pre>
    <pre><code parentName="pre" {...{
        "className": "language-text"
      }}>{`No configuration file found.
No pyproject.toml file found.
stubPath /mnt/typings is not a valid directory.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/linear_bad_typing.py
  /mnt/linear_bad_typing.py:40:26 - error: Expression of type "ndarray[Any, dtype[float64]]" cannot be assigned to declared type "NDArray[float32]"
    "ndarray[Any, dtype[float64]]" is incompatible with "NDArray[float32]"
      TypeVar "_DType_co@ndarray" is covariant
        "dtype[float64]" is incompatible with "dtype[float32]"
          TypeVar "_DTypeScalar_co@dtype" is covariant
            "float64" is incompatible with "float32" (reportGeneralTypeIssues)
  /mnt/linear_bad_typing.py:41:39 - error: Argument of type "NDArray[float32]" cannot be assigned to parameter "b" of type "NDArray[GenericType@Linear]" in function "Linear"
    "NDArray[float32]" is incompatible with "NDArray[float64]"
      TypeVar "_DType_co@ndarray" is covariant
        "dtype[float32]" is incompatible with "dtype[float64]"
          TypeVar "_DTypeScalar_co@dtype" is covariant
            "float32" is incompatible with "float64" (reportGeneralTypeIssues)
2 errors, 0 warnings, 0 informations
Completed in 0.606sec
`}</code></pre>
    <p><inlineCode parentName="p">{`pyright`}</inlineCode>{` has noticed that when we create our `}<inlineCode parentName="p">{`b`}</inlineCode>{` variable, we gave it a
`}<inlineCode parentName="p">{`dtype`}</inlineCode>{` type that is incompatible with `}<inlineCode parentName="p">{`np.random.standard_normal`}</inlineCode>{`.`}</p>
    <p>{`Now we know to adjust the type hint of `}<inlineCode parentName="p">{`b`}</inlineCode>{` to be in line with the `}<inlineCode parentName="p">{`dtype`}</inlineCode>{` that
is expected of `}<inlineCode parentName="p">{`np.random.standard_normal`}</inlineCode>{` (`}<inlineCode parentName="p">{`NDArray[np.float64]`}</inlineCode>{`).`}</p>
    <h2 {...{
      "id": "shape-typing-numpy-arrays"
    }}>{`Shape typing `}<inlineCode parentName="h2">{`numpy`}</inlineCode>{` arrays`}</h2>
    <p>{`While `}<inlineCode parentName="p">{`dtype`}</inlineCode>{` typing is great, it's not the most useful for preventing shape
errors (like from our original example).`}</p>
    <p>{`Ideally it would be great if in addition to a `}<inlineCode parentName="p">{`dtype`}</inlineCode>{` type, we can also
include information about an `}<inlineCode parentName="p">{`ndarray`}</inlineCode>{`'s shape to do shape typing.`}</p>
    <p>{`Shape typing is a technique used to annotate information about the
dimensionality and size of an array. In the context of `}<inlineCode parentName="p">{`numpy`}</inlineCode>{` and the
`}<inlineCode parentName="p">{`python`}</inlineCode>{` type hinting system, we can use shape typing to catch shape errors
before runtime.`}</p>
    <blockquote>
      <p parentName="blockquote">{`For more information about shape typing check out `}<a parentName="p" {...{
          "href": "https://docs.google.com/document/d/1But-hjet8-djv519HEKvBN6Ik2lW3yu0ojZo6pG9osY/edit#heading=h.aw3bt3fg1s2w"
        }}>{`this google doc on a shape
typing syntax proposal`}</a>{` by Matthew Rahtz, Jörg Bornschein, Vlad Mikulik, Tim
Harley, Matthew Willson, Dimitrios Vytiniotis, Sergei Lebedev, Adam Paszke.`}</p>
    </blockquote>
    <p>{`As we've seen, `}<inlineCode parentName="p">{`numpy`}</inlineCode>{`'s `}<inlineCode parentName="p">{`NDArray`}</inlineCode>{` currently only supports `}<inlineCode parentName="p">{`dtype`}</inlineCode>{` typing and
doesn't have any of this kind of shape typing ability. But why is that? If we
dig into the definition of the `}<inlineCode parentName="p">{`NDArray`}</inlineCode>{` type:`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)

if TYPE_CHECKING or sys.version_info >= (3, 9):
    _DType = np.dtype[ScalarType]
    NDArray = np.ndarray[Any, np.dtype[ScalarType]]
else:
    _DType = _GenericAlias(np.dtype, (ScalarType,))
    NDArray = _GenericAlias(np.ndarray, (Any, _DType))
`}</code></pre>
    <p>{`And follow the definition of `}<inlineCode parentName="p">{`np.ndarray`}</inlineCode>{` ...`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
`}</code></pre>
    <p>{`We can see that it looks like `}<inlineCode parentName="p">{`numpy`}</inlineCode>{` uses a `}<inlineCode parentName="p">{`Shape`}</inlineCode>{` type already! But
unfortunately if we look at the definition for this ...`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`# TODO: Set the \`bound\` to something more suitable once we
# have proper shape support
_ShapeType = TypeVar("_ShapeType", bound=Any)
_ShapeType2 = TypeVar("_ShapeType2", bound=Any)
`}</code></pre>
    <p>{`😭 Looks like we're stuck with `}<inlineCode parentName="p">{`Any`}</inlineCode>{` which doesn't add any useful shape
information on our types.`}</p>
    <p>{`Luckily for us, we don't have to wait for shape support in `}<inlineCode parentName="p">{`numpy`}</inlineCode>{`. `}<a parentName="p" {...{
        "href": "https://peps.python.org/pep-0646/"
      }}>{`PEP 646`}</a>{` has
the base foundation for shape typing and has already been accepted into `}<inlineCode parentName="p">{`python==3.11`}</inlineCode>{`! And it's supported by `}<inlineCode parentName="p">{`pyright`}</inlineCode>{`! Theoretically these two things give
us most of the ingredients to do basic shape typing.`}</p>
    <p>{`Now this blog post isn't about the details of `}<a parentName="p" {...{
        "href": "https://peps.python.org/pep-0646/"
      }}>{`PEP 646`}</a>{` or variadic
generics. Understanding PEP 646 will help, but it's not needed to understand
the rest of this post.`}</p>
    <p>{`In order to add rudimentary shape typing to `}<inlineCode parentName="p">{`numpy`}</inlineCode>{` we can simply change the
`}<inlineCode parentName="p">{`Any`}</inlineCode>{` type in the `}<inlineCode parentName="p">{`NDArray`}</inlineCode>{` type definition to an unpacked variadic generic
like so:`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`ScalarType = TypeVar("ScalarType", bound=np.generic, covariant=True)
Shape = TypeVarTuple("Shape")

if TYPE_CHECKING or sys.version_info >= (3, 9):
    _DType = np.dtype[ScalarType]
    NDArray = np.ndarray[*Shape, np.dtype[ScalarType]]
else:
    _DType = _GenericAlias(np.dtype, (ScalarType,))
    NDArray = _GenericAlias(np.ndarray, (Any, _DType))
`}</code></pre>
    <p>{`Doing so allows us to fill in a `}<inlineCode parentName="p">{`Tuple`}</inlineCode>{` based type (indicating shape) in an
`}<inlineCode parentName="p">{`NDArray`}</inlineCode>{` alongside a `}<inlineCode parentName="p">{`dtype`}</inlineCode>{` type. And shape typing with `}<inlineCode parentName="p">{`Tuple`}</inlineCode>{`s enables us to
define function overloads which describe to a type checker the possible ways a
function can change the shape of an `}<inlineCode parentName="p">{`NDArray`}</inlineCode>{`.`}</p>
    <p>{`Let's look at an example of using these concepts to type a wrapper function
for `}<inlineCode parentName="p">{`np.random.standard_normal`}</inlineCode>{` from our `}<inlineCode parentName="p">{`Linear`}</inlineCode>{` example with an intentional
type error:`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`import numpy as np
from numpy.typing import NDArray
from typing import Tuple, TypeVar, Literal

# Generic dimension sizes types
T1 = TypeVar("T1", bound=int)
T2 = TypeVar("T2", bound=int)
T3 = TypeVar("T3", bound=int)

# Dimension types represented as tuples
Shape = Tuple
Shape1D = Shape[T1]
Shape2D = Shape[T1, T2]
Shape3D = Shape[T1, T2, T3]
ShapeND = Shape[T1, ...]
ShapeNDType = TypeVar("ShapeNDType", bound=ShapeND)

def rand_normal_matrix(shape: ShapeNDType) -> NDArray[ShapeNDType, np.float64]:
    """Return a random ND normal matrix."""
    return np.random.standard_normal(size=shape)

# Yay correctly typed 2x2x2 cube!
LENGTH = Literal[2]
cube: NDArray[Shape3D[LENGTH, LENGTH, LENGTH], np.float64] = rand_normal_matrix((2,2,2))
print(cube)

SIDE = Literal[4]

# Uh oh the shapes won't match!
square: NDArray[Shape2D[SIDE, SIDE], np.float64] = rand_normal_matrix((3,3))
print(square)
`}</code></pre>
    <p>{`Notice here there are no `}<inlineCode parentName="p">{`assert`}</inlineCode>{` statements. And instead of several comments
about shape, we indicate shape in the type hint.`}</p>
    <p>{`Now while this code is "runnable", `}<inlineCode parentName="p">{`pyright`}</inlineCode>{` will tell us something else:`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-bash"
      }}>{`py -m pyright bad_shape_typing.py --lib
`}</code></pre>
    <pre><code parentName="pre" {...{
        "className": "language-text"
      }}>{`No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/bad_shape_typing.py
  /mnt/bad_shape_typing.py:30:71 - error: Argument of type "tuple[Literal[3], Literal[3]]" cannot be assigned to parameter "shape" of type "ShapeNDType@rand_normal_matrix" in function "rand_normal_matrix"
    Type "Shape2D[SIDE, SIDE]" cannot be assigned to type "tuple[Literal[3], Literal[3]]" (reportGeneralTypeIssues)
1 error, 0 warnings, 0 informations
Completed in 0.535sec
`}</code></pre>
    <p><inlineCode parentName="p">{`pyright`}</inlineCode>{` is telling us we've incorrectly typed `}<inlineCode parentName="p">{`square`}</inlineCode>{` and that it's
incompatible with a `}<inlineCode parentName="p">{`3x3`}</inlineCode>{` shape. Now we know we need to go back and fix the
type to what a type checker should expect.`}</p>
    <p>{`Huzzah shape typing!!`}</p>
    <h2 {...{
      "id": "moar-numpy-shape-typing"
    }}>{`Moar `}<inlineCode parentName="h2">{`numpy`}</inlineCode>{` shape typing!`}</h2>
    <p>{`Now that we have shape typed one function, let's step it up a notch. Let's try
typing each `}<inlineCode parentName="p">{`numpy`}</inlineCode>{` function in our `}<inlineCode parentName="p">{`Linear`}</inlineCode>{` example to include shape
types. We've already typed `}<inlineCode parentName="p">{`np.random.standard_normal`}</inlineCode>{`, so next let's do
`}<inlineCode parentName="p">{`np.dot`}</inlineCode>{`.`}</p>
    <p>{`If we look at the `}<a parentName="p" {...{
        "href": "https://numpy.org/doc/stable/reference/generated/numpy.dot.html"
      }}>{`docs for `}<inlineCode parentName="a">{`np.dot`}</inlineCode></a>{` there are 5 type cases it supports.`}</p>
    <ol>
      <li parentName="ol">
        <p parentName="li">{`Both arguments are `}<inlineCode parentName="p">{`1D`}</inlineCode>{` arrays`}</p>
      </li>
      <li parentName="ol">
        <p parentName="li">{`Both arguments are `}<inlineCode parentName="p">{`2D`}</inlineCode>{` arrays (resulting in a `}<inlineCode parentName="p">{`matmul`}</inlineCode>{`)`}</p>
      </li>
      <li parentName="ol">
        <p parentName="li">{`Either argument is a scalar`}</p>
      </li>
      <li parentName="ol">
        <p parentName="li">{`Either argument is a `}<inlineCode parentName="p">{`ND`}</inlineCode>{` array and the other is a `}<inlineCode parentName="p">{`1D`}</inlineCode>{` array`}</p>
      </li>
      <li parentName="ol">
        <p parentName="li">{`One argument is an `}<inlineCode parentName="p">{`ND`}</inlineCode>{` array and the other is an `}<inlineCode parentName="p">{`MD`}</inlineCode>{` array`}</p>
      </li>
    </ol>
    <p>{`We can implement these cases as follows`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`ShapeVarGen = TypeVarTuple("ShapeVarGen")

@overload
def dot(x1: NDArray[Shape1D[T1], GenericDType], x2: NDArray[Shape1D[T1], GenericDType], /) -> GenericDType:
    ...


@overload
def dot(
    x1: NDArray[Shape[T1, *ShapeVarGen], GenericDType], x2: NDArray[Shape1D[T1], GenericDType], /
) -> NDArray[Shape[*ShapeVarGen], GenericDType]:
    ...


@overload
def dot(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: NDArray[Shape2D[T2, T3], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T3], GenericDType]:
    ...


@overload
def dot(x1: GenericDType, x2: GenericDType, /) -> GenericDType:
    ...


def dot(x1, x2):
    return np.dot(x1, x2)

`}</code></pre>
    <p>{`The only case we can't implement is an `}<inlineCode parentName="p">{`ND`}</inlineCode>{` dimensional array with an `}<inlineCode parentName="p">{`MD`}</inlineCode>{`
dimensional array. Ideally we would try implementing it like so:`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`ShapeVarGen1 = TypeVarTuple("ShapeVarGen1")
ShapeVarGen2 = TypeVarTuple("ShapeVarGen2")

@overload
def dot(
    x1: NDArray[Shape[*ShapeVarGen1, T1], GenericDType], x2: NDArray[Shape[*ShapeVarGen2, T1, T2], GenericDType], /
) -> NDArray[Shape[*ShapeVarGen1, *ShapeVarGen2], GenericDType]:
    ...
`}</code></pre>
    <p>{`But currently using multiple type variable tuples `}<a parentName="p" {...{
        "href": "https://peps.python.org/pep-0646/#multiple-type-variable-tuples-not-allowed"
      }}>{`is not allowed`}</a>{`. If you know
of another way to cover this case let me know! Luckily for our `}<inlineCode parentName="p">{`Linear`}</inlineCode>{` use
case, it only uses scalars, vectors, and matrices, which are covered by our four
overloads.`}</p>
    <p>{`Here's how we would use these `}<inlineCode parentName="p">{`dot`}</inlineCode>{` overloads to do the dot product between a
`}<inlineCode parentName="p">{`2x3`}</inlineCode>{` matrix and a `}<inlineCode parentName="p">{`3x2`}</inlineCode>{` matrix with type hints:`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`import numpy as np
from numpy.typing import NDArray
from numpy_shape_typing.dot import dot
from numpy_shape_typing.types import ShapeNDType, Shape2D
from numpy_shape_typing.rand import rand_normal_matrix

from typing import Literal

ROWS = Literal[2]
COLS = Literal[3]
A: NDArray[Shape2D[ROWS, COLS], np.float64] = rand_normal_matrix((2,3))
B: NDArray[Shape2D[COLS, ROWS], np.float64] = rand_normal_matrix((3,2))
C: NDArray[Shape2D[ROWS, ROWS], np.float64] = dot(A, B)
print(C)
`}</code></pre>
    <p>{`And if we check with `}<inlineCode parentName="p">{`pyright`}</inlineCode>{`:`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-bash"
      }}>{`py -m pyright good_dot.py --lib
`}</code></pre>
    <pre><code parentName="pre" {...{
        "className": "language-text"
      }}>{`No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
0 errors, 0 warnings, 0 informations
Completed in 0.909sec
`}</code></pre>
    <p>{`Everything looks good as it should!`}</p>
    <p>{`And if we change the types to invalid matrix shapes:`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`import numpy as np
from numpy.typing import NDArray
from numpy_shape_typing.dot import dot
from numpy_shape_typing.rand import rand_normal_matrix
from numpy_shape_typing.types import ShapeNDType, Shape2D

from typing import Literal

ROWS = Literal[2]
COLS = Literal[3]
SLICES = Literal[4]

# uh oh based on these types we can't do a valid dot product!
A: NDArray[Shape2D[ROWS, COLS], np.float64] = rand_normal_matrix((2,3))
B: NDArray[Shape2D[SLICES, COLS], np.float64] = rand_normal_matrix((4,3))
C: NDArray[Shape2D[ROWS, COLS], np.float64] = dot(A, B)
print(C)
`}</code></pre>
    <p>{`And if we check with `}<inlineCode parentName="p">{`pyright`}</inlineCode>{`:`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-bash"
      }}>{`py -m pyright ./bad_dot.py --lib
`}</code></pre>
    <pre><code parentName="pre" {...{
        "className": "language-text"
      }}>{`No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/bad_dot.py
  /mnt/bad_dot.py:16:54 - error: Argument of type "NDArray[Shape2D[SLICES, COLS], float64]" cannot be assigned to parameter "x2" of type "GenericDType@dot" in function "dot"
    Type "NDArray[Shape2D[ROWS, COLS], float64]" cannot be assigned to type "NDArray[Shape2D[SLICES, COLS], float64]" (reportGeneralTypeIssues)
1 error, 0 warnings, 0 informations
Completed in 0.908sec
`}</code></pre>
    <p><inlineCode parentName="p">{`pyright`}</inlineCode>{` lets us know that the types we are using are incorrect shapes based
on `}<inlineCode parentName="p">{`np.dot`}</inlineCode>{`'s type overloads we've specified.`}</p>
    <h2 {...{
      "id": "even-moar-numpy-shape-typing"
    }}>{`Even moar `}<inlineCode parentName="h2">{`numpy`}</inlineCode>{` shape typing!`}</h2>
    <p>{`The next function we are going to type is `}<inlineCode parentName="p">{`np.add`}</inlineCode>{`. The `}<a parentName="p" {...{
        "href": "https://numpy.org/doc/stable/reference/generated/numpy.add.html"
      }}><inlineCode parentName="a">{`numpy`}</inlineCode>{` docs`}</a>{` only show
two cases.`}</p>
    <ol>
      <li parentName="ol">
        <p parentName="li">{`Two `}<inlineCode parentName="p">{`ND`}</inlineCode>{` array arguments of the same shape are added element wise`}</p>
      </li>
      <li parentName="ol">
        <p parentName="li">{`Two `}<inlineCode parentName="p">{`ND`}</inlineCode>{` array arguments that are not the same shape must be broadcastable to
a common shape`}</p>
      </li>
    </ol>
    <p>{`Covering the first case is easy, but the second case is much harder as we
would have to come up with a scheme to cover `}<inlineCode parentName="p">{`numpy`}</inlineCode>{`'s `}<a parentName="p" {...{
        "href": "https://numpy.org/doc/stable/user/basics.broadcasting.html"
      }}>{`array broadcasting
system`}</a>{`. Currently `}<inlineCode parentName="p">{`python==3.11`}</inlineCode>{`'s `}<inlineCode parentName="p">{`typing`}</inlineCode>{` doesn't have a generic way to
cover all the broadcasting rules. (If you know of a way let me know!)`}</p>
    <p>{`However if we scope down the second case to only two dimensions, we can cover
all the array broadcasting rules with a few overloads:`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`from typing import overload

import numpy as np
from numpy.typing import NDArray

from numpy_shape_typing.types import ONE, T1, T2, GenericDType, Shape1D, Shape2D, ShapeVarGen


@overload
def add(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: NDArray[Shape1D[T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape1D[T2], GenericDType],
    x2: NDArray[Shape2D[T1, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: NDArray[Shape1D[ONE], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape1D[ONE], GenericDType],
    x2: NDArray[Shape2D[T1, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: NDArray[Shape2D[T1, ONE], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: NDArray[Shape2D[ONE, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[T1, ONE], GenericDType],
    x2: NDArray[Shape2D[T1, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[ONE, T2], GenericDType],
    x2: NDArray[Shape2D[T1, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: GenericDType,
    x2: NDArray[Shape2D[T1, T2], GenericDType],
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[Shape2D[T1, T2], GenericDType],
    x2: GenericDType,
    /,
) -> NDArray[Shape2D[T1, T2], GenericDType]:
    ...


@overload
def add(
    x1: NDArray[*ShapeVarGen, GenericDType],
    x2: NDArray[*ShapeVarGen, GenericDType],
    /,
) -> NDArray[*ShapeVarGen, GenericDType]:
    ...


def add(x1, x2):
    return np.add(x1, x2)
`}</code></pre>
    <p>{`Using these overloads, here is how we would catch unexpected array broadcasts
(similar to the one from our original `}<inlineCode parentName="p">{`Linear`}</inlineCode>{` example).`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`from typing import Literal

import numpy as np
from numpy.typing import NDArray

from numpy_shape_typing.add import add
from numpy_shape_typing.dot import dot
from numpy_shape_typing.rand import rand_normal_matrix
from numpy_shape_typing.types import ONE, Shape1D, Shape2D

COLS = Literal[4]
A: NDArray[Shape2D[COLS, COLS], np.float64] = rand_normal_matrix((4, 4))
B: NDArray[Shape2D[ONE, COLS], np.float64] = rand_normal_matrix((1, 4))
C: NDArray[Shape2D[ONE, COLS], np.float64] = add(A, B)
print(C)
`}</code></pre>
    <p>{`In the example above, our output is a `}<inlineCode parentName="p">{`4x4`}</inlineCode>{` matrix, but what we want from our
types is an output shape of `}<inlineCode parentName="p">{`1x4`}</inlineCode>{`. Let's see what `}<inlineCode parentName="p">{`pyright`}</inlineCode>{` says`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-bash"
      }}>{`py -m pyright unnexpected_broadcast.py --lib
`}</code></pre>
    <pre><code parentName="pre" {...{
        "className": "language-text"
      }}>{`No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/unnexpected_broadcast.py
  /mnt/unnexpected_broadcast.py:14:50 - error: Argument of type "NDArray[Shape2D[COLS, COLS], float64]" cannot be assigned to parameter "x1" of type "NDArray[*ShapeVarGen@add, GenericDType@add]" in function "add"
    "NDArray[Shape2D[COLS, COLS], float64]" is incompatible with "NDArray[Shape2D[ONE, COLS], float64]"
      TypeVar "_ShapeType@ndarray" is invariant
        "*tuple[Shape2D[COLS, COLS]]" is incompatible with "*tuple[Shape2D[ONE, COLS]]"
          Tuple entry 1 is incorrect type
            "Shape2D[COLS, COLS]" is incompatible with "Shape2D[ONE, COLS]" (reportGeneralTypeIssues)
1 error, 0 warnings, 0 informations
Completed in 2.757sec
`}</code></pre>
    <p><inlineCode parentName="p">{`pyright`}</inlineCode>{` informs us that our shapes are off and that we got broadcasted to a
`}<inlineCode parentName="p">{`4x4`}</inlineCode>{`! Huzzah shape typing!`}</p>
    <h2 {...{
      "id": "hitting-the-limitations-of-shape-typing-"
    }}>{`Hitting the limitations of shape typing 😿`}</h2>
    <p>{`The last function we will type to finish off our `}<inlineCode parentName="p">{`Linear`}</inlineCode>{` example is
`}<inlineCode parentName="p">{`np.ravel`}</inlineCode>{`. However this is where we start hitting some major limitations of
shape typing as they exist today in `}<inlineCode parentName="p">{`python`}</inlineCode>{` and `}<inlineCode parentName="p">{`numpy`}</inlineCode>{`.`}</p>
    <p>{`From the `}<a parentName="p" {...{
        "href": "https://numpy.org/doc/stable/reference/generated/numpy.ravel.html"
      }}>{`numpy docs on`}</a>{` `}<inlineCode parentName="p">{`np.ravel`}</inlineCode>{` the only case we need to cover is that any
`}<inlineCode parentName="p">{`ND`}</inlineCode>{` array gets collapsed into a `}<inlineCode parentName="p">{`1D`}</inlineCode>{` array whose size is the total number of
elements. Luckily, the final `}<inlineCode parentName="p">{`1D`}</inlineCode>{` size is just
the product of all the input dimension sizes.`}</p>
    <p>{`Ideally we would try to write code that looks something like this:`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`ShapeVarGen = TypeVarTuple("ShapeVarGen")

@overload
def ravel(
    arr: NDArray[Shape[*ShapeVarGen], GenericDType]
) -> NDArray[Shape1D[Product[*ShapeVarGen]], GenericDType]:
    ...
`}</code></pre>
    <p>{`But unfortunately `}<inlineCode parentName="p">{`python`}</inlineCode>{`'s `}<inlineCode parentName="p">{`typing`}</inlineCode>{` package currently doesn't have a notion
of a `}<inlineCode parentName="p">{`Product`}</inlineCode>{` type that provides a way to do algebraic typing.`}</p>
    <p>{`However, for the sake of completeness we can fake it!`}</p>
    <p>{`If we scope down from a generic `}<inlineCode parentName="p">{`ND`}</inlineCode>{` typing of `}<inlineCode parentName="p">{`np.ravel`}</inlineCode>{` to support up to two
dimensions and limit the size of the output dimension to some maximum number,
we can overload all the possible factors that multiply to the output dimension
size. We would effectively be typing a multiplication table 😆, but it will
work and get us to a "partially" typed `}<inlineCode parentName="p">{`np.ravel`}</inlineCode>{`.`}</p>
    <p>{`Here's how we can do it.`}</p>
    <p>{`First we create a bunch of `}<inlineCode parentName="p">{`Literal`}</inlineCode>{` types (our factors):`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`ZERO = Literal[0]
ONE = Literal[1]
TWO = Literal[2]
THREE = Literal[3]
FOUR = Literal[4]
...
`}</code></pre>
    <p>{`Then we define "multiply" types for factor pairs of numbers:`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`SHAPE_2D_MUL_TO_ONE = TypeVar(
    "SHAPE_2D_MUL_TO_ONE",
    bound=Shape2D[Literal[ONE], Literal[ONE]],
)
SHAPE_2D_MUL_TO_TWO = TypeVar(
    "SHAPE_2D_MUL_TO_TWO",
    bound=Union[Shape2D[Literal[ONE], Literal[TWO]], Shape2D[Literal[TWO], Literal[ONE]]],
)
SHAPE_2D_MUL_TO_THREE = TypeVar(
    "SHAPE_2D_MUL_TO_THREE",
    bound=Union[Shape2D[Literal[ONE], Literal[THREE]], Shape2D[Literal[THREE], Literal[ONE]]],
)
SHAPE_2D_MUL_TO_FOUR = TypeVar(
    "SHAPE_2D_MUL_TO_FOUR",
    bound=Union[
        Shape2D[Literal[ONE], Literal[FOUR]],
        Shape2D[Literal[TWO], Literal[TWO]],
        Shape2D[Literal[FOUR], Literal[ONE]],
    ],
)
`}</code></pre>
    <p>{`Then lastly we wire these types up into individual `}<inlineCode parentName="p">{`ravel`}</inlineCode>{` overloads (and
cover a few generic ones while we're at it):`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`@overload
def ravel(arr: NDArray[SHAPE_2D_MUL_TO_ONE, GenericDType]) -> NDArray[Shape1D[ONE], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[SHAPE_2D_MUL_TO_TWO, GenericDType]) -> NDArray[Shape1D[TWO], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[SHAPE_2D_MUL_TO_THREE, GenericDType]) -> NDArray[Shape1D[THREE], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[SHAPE_2D_MUL_TO_FOUR, GenericDType]) -> NDArray[Shape1D[FOUR], GenericDType]:
    ...

@overload
def ravel(arr: NDArray[Shape2D[T1, ONE], GenericDType]) -> NDArray[Shape1D[T1], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[Shape2D[ONE, T1], GenericDType]) -> NDArray[Shape1D[T1], GenericDType]:
    ...


@overload
def ravel(arr: NDArray[Shape1D[T1], GenericDType]) -> NDArray[Shape1D[T1], GenericDType]:
    ...
`}</code></pre>
    <p>{`Now we can rinse and repeat for as many numbers as we like!`}</p>
    <p>{`Here is how we'd use this typing to catch a shape type error with `}<inlineCode parentName="p">{`ravel`}</inlineCode>{`:`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`import numpy as np
from numpy.typing import NDArray

from numpy_shape_typing.rand import rand_normal_matrix
from numpy_shape_typing.ravel import ravel
from numpy_shape_typing.types import FOUR, SEVEN, TWO, Shape1D, Shape2D

A: NDArray[Shape2D[TWO, FOUR], np.float64] = rand_normal_matrix((2, 4))
B: NDArray[Shape1D[SEVEN], np.float64] = ravel(A)
print(B)
`}</code></pre>
    <pre><code parentName="pre" {...{
        "className": "language-bash"
      }}>{`py -m pyright raveling.py --lib
`}</code></pre>
    <pre><code parentName="pre" {...{
        "className": "language-text"
      }}>{`No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/raveling.py
  /mnt/raveling.py:9:42 - error: Expression of type "NDArray[Shape1D[EIGHT], float64]" cannot be assigned to declared type "NDArray[Shape1D[SEVEN], float64]"
    "NDArray[Shape1D[EIGHT], float64]" is incompatible with "NDArray[Shape1D[SEVEN], float64]"
      TypeVar "_ShapeType@ndarray" is invariant
        "*tuple[Shape1D[EIGHT]]" is incompatible with "*tuple[Shape1D[SEVEN]]"
          Tuple entry 1 is incorrect type
            "Shape1D[EIGHT]" is incompatible with "Shape1D[SEVEN]" (reportGeneralTypeIssues)
1 error, 0 warnings, 0 informations
Completed in 0.933sec
`}</code></pre>
    <h2 {...{
      "id": "putting-it-all-together"
    }}>{`Putting it all together`}</h2>
    <p>{`So far we've gone through typing a small subset of `}<inlineCode parentName="p">{`numpy`}</inlineCode>{`'s functions
(`}<inlineCode parentName="p">{`np.random.standard_normal`}</inlineCode>{`, `}<inlineCode parentName="p">{`np.dot`}</inlineCode>{`, `}<inlineCode parentName="p">{`np.add`}</inlineCode>{`, and `}<inlineCode parentName="p">{`np.ravel`}</inlineCode>{` in all).`}</p>
    <p>{`Now we can chain these typed functions together to form a typed `}<inlineCode parentName="p">{`Linear`}</inlineCode>{`
implementation like so:`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`from typing import Literal

import numpy as np
from numpy.typing import NDArray

from numpy_shape_typing.add import add
from numpy_shape_typing.dot import dot
from numpy_shape_typing.rand import rand_normal_matrix
from numpy_shape_typing.ravel import ravel
from numpy_shape_typing.types import ONE, T1, T2, GenericDType, Shape1D, Shape2D


def Linear(
    A: NDArray[Shape2D[T1, T2], GenericDType],
    x: NDArray[Shape2D[T2, ONE], GenericDType],
    b: NDArray[Shape2D[T1, ONE], GenericDType],
) -> NDArray[Shape1D[T1], GenericDType]:
    Ax = dot(A, x)
    Axb = add(Ax, b)
    return ravel(Axb)


IN_DIM = Literal[3]
in_dim: IN_DIM = 3

OUT_DIM = Literal[4]
out_dim: OUT_DIM = 4

# bad type >:(
BAD_OUT_DIM = Literal[5]

A: NDArray[Shape2D[OUT_DIM, IN_DIM], np.float64] = rand_normal_matrix((out_dim, in_dim))
x: NDArray[Shape2D[IN_DIM, ONE], np.float64] = rand_normal_matrix((in_dim, 1))
b: NDArray[Shape2D[OUT_DIM, ONE], np.float64] = rand_normal_matrix((out_dim, 1))

# this is a bad type!
y: NDArray[Shape1D[BAD_OUT_DIM], np.float64] = Linear(A, x, b)
`}</code></pre>
    <p>{`I've included an intentional type error which should be caught by `}<inlineCode parentName="p">{`pyright`}</inlineCode>{`
like so:`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-bash"
      }}>{`py -m pyright linear_type_bad.py --lib
`}</code></pre>
    <pre><code parentName="pre" {...{
        "className": "language-text"
      }}>{`No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
/mnt/linear_type_bad.py
  /mnt/linear_type_bad.py:37:55 - error: Argument of type "NDArray[Shape2D[OUT_DIM, IN_DIM], float64]" cannot be assigned to parameter "A" of type "NDArray[Shape2D[T1@Linear, T2@Linear], GenericDType@Linear]" in function "Linear"
    "NDArray[Shape2D[OUT_DIM, IN_DIM], float64]" is incompatible with "NDArray[Shape2D[BAD_OUT_DIM, IN_DIM], float64]"
      TypeVar "_ShapeType@ndarray" is invariant
        "*tuple[Shape2D[OUT_DIM, IN_DIM]]" is incompatible with "*tuple[Shape2D[BAD_OUT_DIM, IN_DIM]]"
          Tuple entry 1 is incorrect type
            "Shape2D[OUT_DIM, IN_DIM]" is incompatible with "Shape2D[BAD_OUT_DIM, IN_DIM]" (reportGeneralTypeIssues)
  /mnt/linear_type_bad.py:37:61 - error: Argument of type "NDArray[Shape2D[OUT_DIM, ONE], float64]" cannot be assigned to parameter "b" of type "NDArray[Shape2D[T1@Linear, ONE], GenericDType@Linear]" in function "Linear"
    "NDArray[Shape2D[OUT_DIM, ONE], float64]" is incompatible with "NDArray[Shape2D[BAD_OUT_DIM, ONE], float64]"
      TypeVar "_ShapeType@ndarray" is invariant
        "*tuple[Shape2D[OUT_DIM, ONE]]" is incompatible with "*tuple[Shape2D[BAD_OUT_DIM, ONE]]"
          Tuple entry 1 is incorrect type
            "Shape2D[OUT_DIM, ONE]" is incompatible with "Shape2D[BAD_OUT_DIM, ONE]" (reportGeneralTypeIssues)
2 errors, 0 warnings, 0 informations
Completed in 8.155sec
`}</code></pre>
    <p>{`And huzzah again! `}<inlineCode parentName="p">{`pyright`}</inlineCode>{` has caught the shape type error!`}</p>
    <p>{`And now we can fix this shape error by changing `}<inlineCode parentName="p">{`BAD_OUT_DIM`}</inlineCode>{` to the correct
output dimension size.`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-python"
      }}>{`from typing import Literal

import numpy as np
from numpy.typing import NDArray

from numpy_shape_typing.add import add
from numpy_shape_typing.dot import dot
from numpy_shape_typing.rand import rand_normal_matrix
from numpy_shape_typing.ravel import ravel
from numpy_shape_typing.types import ONE, T1, T2, GenericDType, Shape1D, Shape2D


def Linear(
    A: NDArray[Shape2D[T1, T2], GenericDType],
    x: NDArray[Shape2D[T2, ONE], GenericDType],
    b: NDArray[Shape2D[T1, ONE], GenericDType],
) -> NDArray[Shape1D[T1], GenericDType]:
    """
    Args:
        A: ndarray of shape (M x N)
        x: ndarray of shape (N x 1)
        b: ndarray of shape (M x 1)

    Returns:
        Linear output ndarray of shape (M)
    """
    Ax = dot(A, x)
    Axb = add(Ax, b)
    return ravel(Axb)


IN_DIM = Literal[3]
in_dim: IN_DIM = 3

OUT_DIM = Literal[4]
out_dim: OUT_DIM = 4

A: NDArray[Shape2D[OUT_DIM, IN_DIM], np.float64] = rand_normal_matrix((out_dim, in_dim))
x: NDArray[Shape2D[IN_DIM, ONE], np.float64] = rand_normal_matrix((in_dim, 1))
b: NDArray[Shape2D[OUT_DIM, ONE], np.float64] = rand_normal_matrix((out_dim, 1))
y: NDArray[Shape1D[OUT_DIM], np.float64] = Linear(A, x, b)
`}</code></pre>
    <p>{`And if we check with `}<inlineCode parentName="p">{`pyright`}</inlineCode>{`.`}</p>
    <pre><code parentName="pre" {...{
        "className": "language-bash"
      }}>{`py -m pyright linear_type_good.py --lib
`}</code></pre>
    <pre><code parentName="pre" {...{
        "className": "language-text"
      }}>{`No configuration file found.
No pyproject.toml file found.
Assuming Python platform Linux
Searching for source files
Found 1 source file
pyright 1.1.299
0 errors, 0 warnings, 0 informations
Completed in 8.116sec
`}</code></pre>
    <p><inlineCode parentName="p">{`pyright`}</inlineCode>{` tells us that our types are consistent!`}</p>
    <h2 {...{
      "id": "whats-next"
    }}>{`What's next?`}</h2>
    <p>{`You tell me! Many open source scientific computing libraries have GitHub issues
about shape typing such as:`}</p>
    <ul>
      <li parentName="ul"><inlineCode parentName="li">{`numpy`}</inlineCode>{`: `}<a parentName="li" {...{
          "href": "https://github.com/numpy/numpy/issues/16544"
        }}>{`https://github.com/numpy/numpy/issues/16544`}</a></li>
      <li parentName="ul"><inlineCode parentName="li">{`jax`}</inlineCode>{`: `}<a parentName="li" {...{
          "href": "https://github.com/google/jax/issues/12049"
        }}>{`https://github.com/google/jax/issues/12049`}</a></li>
      <li parentName="ul"><inlineCode parentName="li">{`pytorch`}</inlineCode>{`: `}<a parentName="li" {...{
          "href": "https://github.com/pytorch/pytorch/issues/33953"
        }}>{`https://github.com/pytorch/pytorch/issues/33953`}</a></li>
    </ul>
    <p>{`So it's well recognized as a desirable feature. Some of the major technical
hurdles we still need to overcome are:`}</p>
    <ul>
      <li parentName="ul"><a parentName="li" {...{
          "href": "https://github.com/python/mypy/issues/12280"
        }}>{`PEP 646 in mypy`}</a></li>
      <li parentName="ul"><a parentName="li" {...{
          "href": "https://github.com/python/mypy/issues/11990"
        }}>{`Arithmetic between Literals`}</a></li>
      <li parentName="ul"><a parentName="li" {...{
          "href": "https://peps.python.org/pep-0646/#multiple-type-variable-tuples-not-allowed"
        }}>{`Multiple type variable tuples`}</a></li>
      <li parentName="ul"><a parentName="li" {...{
          "href": "https://peps.python.org/pep-0646/#variance-type-constraints-and-type-bounds-not-yet-supported"
        }}>{`Type bounds for variadic generics`}</a></li>
    </ul>
    <p>{`Once these hurdles are overcome I don't see any blockers stopping projects
like `}<inlineCode parentName="p">{`numpy`}</inlineCode>{` from being fully shape typed.`}</p>
    <p>{`This post and `}<a parentName="p" {...{
        "href": "https://github.com/cmrfrd/numpy_shape_typing"
      }}>{`accompanying repo`}</a>{` are just a sample of what shape typing
might become. With future PEPs and work on the `}<inlineCode parentName="p">{`python`}</inlineCode>{` type hinting system,
we'll hopefully make our code incrementally safer.`}</p>
    <p>{`Thanks for reading! (っ◔◡◔)っ ♥`}</p>

    </MDXLayout>;
}
;
MDXContent.isMDXComponent = true;
      