File size: 639 Bytes
29547e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from typing import Optional
import torch
from ._ops import ops

def gemm(a: torch.Tensor, b: torch.Tensor, as_: torch.Tensor, bs: torch.Tensor,
         out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Invoke the compiled ``ops.gemm`` kernel and return its output tensor.

    Args:
        a: Left operand. When ``out`` is allocated here, its shape is
            unpacked as ``(M, K)``, so it must be 2-D in that case.
        b: Right operand. When ``out`` is allocated here, its shape is
            unpacked as ``(K, N)``, so it must be 2-D in that case.
        as_: Forwarded unchanged to ``ops.gemm``.
            NOTE(review): presumably the scale tensor for ``a`` -- confirm
            against the kernel.
        bs: Forwarded unchanged to ``ops.gemm``.
            NOTE(review): presumably the scale tensor for ``b`` -- confirm
            against the kernel.
        out: Optional destination tensor. When ``None``, an uninitialized
            ``torch.bfloat16`` tensor of shape ``(M, N)`` is allocated on
            ``a.device``. When provided, it is passed to the kernel as-is and
            no shape check is performed by this wrapper.

    Returns:
        The tensor handed to the kernel as its output: ``out`` itself when it
        was supplied, otherwise the newly allocated tensor.

    Raises:
        AssertionError: If ``out`` is ``None`` and the inner dimensions of
            ``a`` and ``b`` differ.
    """
    if out is None:
        M, K = a.shape
        K_b, N = b.shape
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped when Python runs with -O; AssertionError is kept so that
        # existing callers observe the same exception type and message.
        if K != K_b:
            raise AssertionError(
                f"Matrix dimension mismatch: A has {K} cols, B has {K_b} rows"
            )

        # Allocate the result as BF16 on the same device as ``a``.
        out = torch.empty((M, N), dtype=torch.bfloat16, device=a.device)

    ops.gemm(out, a, b, as_, bs)
    return out