| from typing import Optional | |
| import torch | |
| from ._ops import ops | |
def gemm(a: torch.Tensor, b: torch.Tensor, as_: torch.Tensor, bs: torch.Tensor,
         out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run the ``ops.gemm`` kernel on ``a`` and ``b`` and return the result.

    Args:
        a: Left operand. When ``out`` is omitted it must be 2-D, and its
            shape is read as ``(M, K)``.
        b: Right operand. When ``out`` is omitted it must be 2-D, and its
            shape is read as ``(K, N)``.
        as_: Forwarded unchanged to ``ops.gemm``. Presumably the scale
            tensor that belongs to ``a`` -- TODO confirm against the kernel.
        bs: Forwarded unchanged to ``ops.gemm``. Presumably the scale
            tensor that belongs to ``b`` -- TODO confirm against the kernel.
        out: Optional preallocated destination. When omitted, an
            uninitialised ``(M, N)`` ``torch.bfloat16`` tensor is allocated
            on ``a.device``. When supplied it is handed to the kernel as-is;
            no shape, dtype or device checks are performed here.

    Returns:
        The tensor passed to the kernel as its output: ``out`` itself when
        one was supplied, otherwise the newly allocated tensor.

    Raises:
        ValueError: If ``out`` is omitted and the column count of ``a``
            differs from the row count of ``b`` (or either operand is not
            2-D, in which case the shape unpacking fails).
    """
    if out is None:
        # Derive the output shape from the operands.
        M, K = a.shape
        K_b, N = b.shape
        # Raise explicitly instead of using `assert`: assertions are removed
        # under `python -O`, which would let mismatched operands reach the
        # native kernel with an output buffer sized from inconsistent shapes.
        if K != K_b:
            raise ValueError(
                f"Matrix dimension mismatch: A has {K} cols, B has {K_b} rows"
            )
        # The result is allocated as BF16 on the same device as `a`.
        out = torch.empty((M, N), dtype=torch.bfloat16, device=a.device)

    # The kernel writes its result into `out` in place.
    ops.gemm(out, a, b, as_, bs)
    return out