# ANE-sha256d / model.py
# pkhairkh's picture
# Initial commit
# e2c008f
import numpy as np
import coremltools as ct
from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.mil import types
def bit_const(v):
    """Return a MIL constant holding the low 32 bits of the integer *v*.

    The constant is a float16 tensor of shape (1, 32, 1, 1); index ``i`` on
    axis 1 stores bit ``i`` of *v* (least-significant bit first) as 0.0/1.0.
    """
    planes = np.empty((1, 32, 1, 1), dtype=np.float16)
    for pos in range(32):
        planes[0, pos, 0, 0] = (v >> pos) & 1
    return mb.const(val=planes)
def band(a, b):
    """Bitwise AND of two bit tensors, computed as an element-wise product.

    The result is only a valid AND when both inputs hold 0/1 values.
    """
    product = mb.mul(x=a, y=b)
    return product
def bxor(a, b):
    """Bitwise XOR of two bit tensors, computed as ``|a - b|``.

    The result is only a valid XOR when both inputs hold 0/1 values.
    """
    difference = mb.sub(x=a, y=b)
    return mb.abs(x=difference)
def bor(a, b):
    """Bitwise OR of two bit tensors, computed as the element-wise maximum.

    The result is only a valid OR when both inputs hold 0/1 values.
    """
    larger = mb.maximum(x=a, y=b)
    return larger
def xor3(a, b, c):
    """Three-way XOR ``a ^ b ^ c`` of bit tensors, as ``||a - b| - c|``."""
    a_xor_b = mb.abs(x=mb.sub(x=a, y=b))
    with_c = mb.sub(x=a_xor_b, y=c)
    return mb.abs(x=with_c)
def maj(a, b, c):
    """Bitwise majority of three bit tensors.

    Computes ``(a & b) | (a & c) | (b & c)`` with ``minimum`` acting as AND
    and ``maximum`` acting as OR.
    """
    a_and_b = mb.minimum(x=a, y=b)
    a_and_c = mb.minimum(x=a, y=c)
    either_with_a = mb.maximum(x=a_and_b, y=a_and_c)
    b_and_c = mb.minimum(x=b, y=c)
    return mb.maximum(x=either_with_a, y=b_and_c)
def ch(e, f, g):
    """Bitwise choose: take the bit of *f* where *e* is 1, else the bit of *g*.

    Evaluated as ``g ^ (e & (f ^ g))``.
    """
    differs = bxor(f, g)
    selected = band(e, differs)
    return bxor(g, selected)
# Per-amount caches of the 32x32 permutation/selection weight constants, so
# each distinct rotate/shift amount creates its MIL constant only once.
# Keys are the amount k; values are whatever _w_rotr/_w_shl/_w_shr return.
# NOTE(review): the cached constants are created while `prog` is being built;
# presumably they must not be reused if a second program is ever built.
_W_ROTR = {}
_W_SHL = {}
_W_SHR = {}
def _w_rotr(k):
    """Build the (32, 32, 1, 1) float16 weight constant for a rotate-right by *k*.

    Row ``o`` has a single 1 in column ``(o + k) % 32``, so output bit ``o``
    copies input bit ``(o + k) % 32``.
    """
    weight = np.zeros((32, 32, 1, 1), dtype=np.float16)
    out_bits = np.arange(32)
    weight[out_bits, (out_bits + k) % 32, 0, 0] = np.float16(1.0)
    return mb.const(val=weight)
def _w_shl(k):
    """Build the (32, 32, 1, 1) float16 weight constant for a left shift by *k*.

    Output bit ``o`` copies input bit ``o - k``; rows whose source index
    would be negative stay all-zero, so zeros are shifted in at the bottom.
    """
    weight = np.zeros((32, 32, 1, 1), dtype=np.float16)
    out_bits = np.arange(32)
    src_bits = out_bits - k
    keep = src_bits >= 0
    weight[out_bits[keep], src_bits[keep], 0, 0] = np.float16(1.0)
    return mb.const(val=weight)
def _w_shr(k):
    """Build the (32, 32, 1, 1) float16 weight constant for a right shift by *k*.

    Output bit ``o`` copies input bit ``o + k``; rows whose source index
    would be 32 or more stay all-zero, so zeros are shifted in at the top.
    """
    weight = np.zeros((32, 32, 1, 1), dtype=np.float16)
    out_bits = np.arange(32)
    src_bits = out_bits + k
    keep = src_bits < 32
    weight[out_bits[keep], src_bits[keep], 0, 0] = np.float16(1.0)
    return mb.const(val=weight)
def rotr(x, k):
    """Rotate the 32 bit planes of *x* right by *k* positions.

    *k* is reduced modulo 32; a zero rotation returns *x* itself. Otherwise
    the rotation is a convolution with a 32x32 permutation weight, which is
    created once per amount and cached in ``_W_ROTR``.
    """
    amount = k % 32
    if not amount:
        return x
    try:
        weight = _W_ROTR[amount]
    except KeyError:
        weight = _W_ROTR[amount] = _w_rotr(amount)
    return mb.conv(x=x, weight=weight, pad_type="valid", groups=1)
def shl(x, k):
    """Shift the 32 bit planes of *x* left by *k* positions, filling with zeros.

    A non-positive *k* returns *x* unchanged. A shift of 32 or more clears
    the whole word; previously such amounts were clamped to 31 and produced
    ``x << 31`` instead of zero. The shift is a convolution with a 32x32
    selection weight, created once per amount and cached in ``_W_SHL``.
    """
    if k <= 0:
        return x
    # Every amount >= 32 yields the same all-zero weight matrix, so fold them
    # onto a single cache entry.
    k = min(k, 32)
    if k not in _W_SHL:
        _W_SHL[k] = _w_shl(k)
    return mb.conv(x=x, weight=_W_SHL[k], pad_type="valid", groups=1)
def shr(x, k):
    """Shift the 32 bit planes of *x* right by *k* positions, filling with zeros.

    A non-positive *k* returns *x* unchanged. A shift of 32 or more clears
    the whole word; previously such amounts were clamped to 31 and produced
    ``x >> 31`` instead of zero. The shift is a convolution with a 32x32
    selection weight, created once per amount and cached in ``_W_SHR``.
    """
    if k <= 0:
        return x
    # Every amount >= 32 yields the same all-zero weight matrix, so fold them
    # onto a single cache entry.
    k = min(k, 32)
    if k not in _W_SHR:
        _W_SHR[k] = _w_shr(k)
    return mb.conv(x=x, weight=_W_SHR[k], pad_type="valid", groups=1)
def Sigma0(x):
    """SHA-256 big-sigma-0: ``rotr(x, 2) ^ rotr(x, 13) ^ rotr(x, 22)``."""
    r2 = rotr(x, 2)
    r13 = rotr(x, 13)
    r22 = rotr(x, 22)
    return xor3(r2, r13, r22)
def Sigma1(x):
    """SHA-256 big-sigma-1: ``rotr(x, 6) ^ rotr(x, 11) ^ rotr(x, 25)``."""
    r6 = rotr(x, 6)
    r11 = rotr(x, 11)
    r25 = rotr(x, 25)
    return xor3(r6, r11, r25)
def sigma0(x):
    """SHA-256 small-sigma-0: ``rotr(x, 7) ^ rotr(x, 18) ^ shr(x, 3)``."""
    r7 = rotr(x, 7)
    r18 = rotr(x, 18)
    s3 = shr(x, 3)
    return xor3(r7, r18, s3)
def sigma1(x):
    """SHA-256 small-sigma-1: ``rotr(x, 17) ^ rotr(x, 19) ^ shr(x, 10)``."""
    r17 = rotr(x, 17)
    r19 = rotr(x, 19)
    s10 = shr(x, 10)
    return xor3(r17, r19, s10)
def csa(a, b, c):
    """Carry-save adder stage: reduce three operands to a (sum, carry) pair.

    Returns ``(xor3(a, b, c), maj(a, b, c))``. The carry is NOT shifted here;
    callers shift it left by one before feeding it to the next stage.
    """
    sum_bits = xor3(a, b, c)
    carry_bits = maj(a, b, c)
    return sum_bits, carry_bits
def cpa(a, b):
    """Carry-propagate adder: return ``a + b`` modulo 2**32 on bit tensors.

    Carries are resolved with a parallel-prefix (Kogge-Stone style) scan:
    starting from per-bit generate ``a & b`` and propagate ``a ^ b``, each
    step combines every position with the one ``step`` bits below it, for
    steps 1, 2, 4, 8, 16. The sum is ``(a ^ b) ^ (carries << 1)``.
    """
    half_sum = bxor(a, b)
    propagate = half_sum
    generate = band(a, b)
    steps = (1, 2, 4, 8, 16)
    for step in steps:
        generate = bor(generate, band(propagate, shl(generate, step)))
        # The propagate term is only read by the next step, so skip the
        # update after the last one instead of emitting unused graph ops.
        if step != steps[-1]:
            propagate = band(propagate, shl(propagate, step))
    return bxor(half_sum, shl(generate, 1))
def add2(a, b):
    """Return ``a + b`` modulo 2**32; a thin alias for :func:`cpa`."""
    total = cpa(a, b)
    return total
def add3(a, b, c):
    """Return ``a + b + c`` modulo 2**32.

    One carry-save stage reduces the three operands to two, which are then
    summed by a single carry-propagate adder.
    """
    partial_sum, carries = csa(a, b, c)
    shifted_carries = shl(carries, 1)
    return cpa(partial_sum, shifted_carries)
def add4(a, b, c, d):
    """Return ``a + b + c + d`` modulo 2**32.

    Two carry-save stages reduce the four operands to two, which are then
    summed by a single carry-propagate adder. The earlier version spent an
    extra stage adding an all-zero operand (and allocated a fresh zero
    constant per call); the result is the same without it.
    """
    # a + b + c == s1 + 2*c1
    s1, c1 = csa(a, b, c)
    # s1 + 2*c1 + d == s2 + 2*c2
    s2, c2 = csa(s1, shl(c1, 1), d)
    return cpa(s2, shl(c2, 1))
def add5(a, b, c, d, e):
    """Return ``a + b + c + d + e`` modulo 2**32.

    Three carry-save stages reduce the five operands to two, which are then
    summed by a single carry-propagate adder.
    """
    # a + b + c == sum_abc + 2*carry_abc
    sum_abc, carry_abc = csa(a, b, c)
    # d + e + sum_abc == sum_all + 2*carry_de
    sum_all, carry_de = csa(d, e, sum_abc)
    # Fold the two outstanding carry words into one.
    reduced_sum, reduced_carry = csa(sum_all, shl(carry_abc, 1), shl(carry_de, 1))
    return cpa(reduced_sum, shl(reduced_carry, 1))
# SHA-256 round constants K[0..63]: the first 32 bits of the fractional parts
# of the cube roots of the first 64 primes. Four constants per row.
K_vals = [
    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5,  # rounds 0-3
    0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,  # rounds 4-7
    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,  # rounds 8-11
    0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,  # rounds 12-15
    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc,  # rounds 16-19
    0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,  # rounds 20-23
    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,  # rounds 24-27
    0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,  # rounds 28-31
    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,  # rounds 32-35
    0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,  # rounds 36-39
    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3,  # rounds 40-43
    0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,  # rounds 44-47
    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5,  # rounds 48-51
    0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,  # rounds 52-55
    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,  # rounds 56-59
    0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,  # rounds 60-63
]
# SHA-256 initial hash value H0..H7: the first 32 bits of the fractional
# parts of the square roots of the first 8 primes.
IV_vals = [
    0x6a09e667,
    0xbb67ae85,
    0x3c6ef372,
    0xa54ff53a,
    0x510e527f,
    0x9b05688c,
    0x1f83d9ab,
    0x5be0cd19,
]
# Flexible leading dimension: between 1 and 1024 entries, 1 by default.
flexN = ct.RangeDim(lower_bound=1, upper_bound=1024, default=1)
# Symbol standing in for that dimension inside the MIL input shapes.
N = flexN.symbol
def _expand_schedule(first16):
    """Extend 16 message words to the 64-word SHA-256 message schedule."""
    sched = list(first16)
    for t in range(16, 64):
        sched.append(
            add4(sigma1(sched[t - 2]), sched[t - 7], sigma0(sched[t - 15]), sched[t - 16])
        )
    return sched


def _compress(state, sched, round_consts):
    """Run the 64 SHA-256 rounds on *state* and add the input state back in."""
    a, b, c, d, e, f, g, h = state
    for w_t, k_t in zip(sched, round_consts):
        t1 = add5(h, Sigma1(e), ch(e, f, g), w_t, k_t)
        t2 = add2(Sigma0(a), maj(a, b, c))
        a, b, c, d, e, f, g, h = add2(t1, t2), a, b, c, add2(d, t1), e, f, g
    # Feed-forward: new chaining value = old chaining value + working variables.
    return [add2(prev, cur) for prev, cur in zip(state, (a, b, c, d, e, f, g, h))]


@mb.program(
    input_specs=[
        mb.TensorSpec(shape=(N, 32, 1, 8), dtype=types.fp16),
        mb.TensorSpec(shape=(N, 32, 1, 16), dtype=types.fp16),
    ],
    opset_version=ct.target.iOS18,
)
def prog(midstate, w_init):
    """Build the bit-sliced double SHA-256 graph.

    Args:
        midstate: (N, 32, 1, 8) fp16 tensor. Axis 3 indexes the eight 32-bit
            chaining words the first compression starts from; axis 1 indexes
            the bits of each word, least-significant first (the layout used
            by ``bit_const``).
        w_init: (N, 32, 1, 16) fp16 tensor holding the sixteen message words
            of the block to compress, in the same bit layout.

    Returns:
        (N, 32, 1, 8) tensor with the eight words of the second hash.

    The gate helpers are only exact on 0/1 values, so both inputs are
    expected to contain only 0.0 and 1.0.
    """
    k_bits = [bit_const(k) for k in K_vals]
    iv_bits = [bit_const(v) for v in IV_vals]
    pad_bit = bit_const(0x80000000)
    len_256 = bit_const(256)
    zero = mb.const(val=np.zeros((1, 32, 1, 1), dtype=np.float16))

    # First hash: one compression of the supplied block from `midstate`.
    start_state = list(mb.split(x=midstate, axis=3, num_splits=8))
    first_block = mb.split(x=w_init, axis=3, num_splits=16)
    digest1 = _compress(start_state, _expand_schedule(first_block), k_bits)

    # Second hash: the 8-word digest, padded to a full block with the 0x80
    # marker bit, six zero words and the bit length (256) in the last word,
    # compressed from the standard initial value.
    second_block = digest1 + [pad_bit] + [zero] * 6 + [len_256]
    digest2 = _compress(iv_bits, _expand_schedule(second_block), k_bits)
    return mb.concat(values=digest2, axis=3)
# Convert the MIL program to an fp16 ML Program targeting CPU + Neural Engine,
# exposing both inputs with the flexible leading dimension, then write the
# package to disk.
mlmodel = ct.convert(
    prog,
    inputs=[
        ct.TensorType(name="midstate", dtype=np.float16, shape=(flexN, 32, 1, 8)),
        ct.TensorType(name="w_init", dtype=np.float16, shape=(flexN, 32, 1, 16)),
    ],
    minimum_deployment_target=ct.target.iOS18,
    convert_to="mlprogram",
    compute_precision=ct.precision.FLOAT16,
    compute_units=ct.ComputeUnit.CPU_AND_NE,
    debug=True,
)
mlmodel.save("sha256d.mlpackage")