|
|
import numpy as np |
|
|
import coremltools as ct |
|
|
from coremltools.converters.mil import Builder as mb |
|
|
from coremltools.converters.mil.mil import types |
|
|
|
|
|
def bit_const(v):
    """Return the integer ``v`` as a bit-sliced MIL constant.

    The low 32 bits of ``v`` are laid out least-significant bit first along
    axis 1 of a (1, 32, 1, 1) float16 tensor, each entry 0.0 or 1.0. This is
    the bit order the rotate/shift weight matrices in this module assume.
    """
    bits = np.fromiter(
        ((v >> position) & 1 for position in range(32)),
        dtype=np.float16,
        count=32,
    )
    return mb.const(val=bits.reshape(1, 32, 1, 1))
|
|
|
|
|
def band(a,b):
    """Bitwise AND of two bit-sliced words.

    Assumes every element is exactly 0.0 or 1.0; the element-wise product is
    then 1 only where both bits are 1.
    """
    both_set = mb.mul(x=a, y=b)
    return both_set
|
|
|
|
|
def bxor(a,b):
    """Bitwise XOR of two bit-sliced words.

    Assumes every element is exactly 0.0 or 1.0; ``|a - b|`` is then 1 exactly
    where the two bits differ.
    """
    difference = mb.sub(x=a, y=b)
    return mb.abs(x=difference)
|
|
|
|
|
def bor(a,b):
    """Bitwise OR of two bit-sliced words.

    Assumes every element is exactly 0.0 or 1.0; the element-wise maximum is
    then 1 wherever at least one bit is 1.
    """
    either_set = mb.maximum(x=a, y=b)
    return either_set
|
|
|
|
|
def xor3(a,b,c):
    """Three-way bitwise XOR ``a ^ b ^ c`` of bit-sliced 0/1 words.

    Built as two chained two-input XORs (``|(|a - b|) - c|``); this is also the
    sum output of a full adder.
    """
    return bxor(bxor(a, b), c)
|
|
|
|
|
def maj(a,b,c):
    """Bitwise majority of three bit-sliced 0/1 words (SHA-256 ``Maj``).

    A result bit is 1 when at least two of the three input bits are 1, computed
    as ``max(max(min(a, b), min(a, c)), min(b, c))``. It is also the carry
    output of a full adder.
    """
    a_and_b = mb.minimum(x=a, y=b)
    a_and_c = mb.minimum(x=a, y=c)
    pairs_with_a = mb.maximum(x=a_and_b, y=a_and_c)
    b_and_c = mb.minimum(x=b, y=c)
    return mb.maximum(x=pairs_with_a, y=b_and_c)
|
|
|
|
|
def ch(e,f,g):
    """SHA-256 ``Ch`` (choose) on bit-sliced 0/1 words.

    Each result bit comes from ``f`` where ``e`` is 1 and from ``g`` where ``e``
    is 0. It is evaluated as ``g ^ (e & (f ^ g))`` so that no NOT is needed.
    """
    f_differs_from_g = bxor(f, g)
    chosen_flips = band(e, f_differs_from_g)
    return bxor(g, chosen_flips)
|
|
|
|
|
# Per-shift-amount caches of the conv weight constants used by rotr/shl/shr,
# so that each distinct amount emits a single const op in the program.
# NOTE(review): the cached values are MIL Vars created while `prog` is traced;
# reusing them while building a second program would presumably reference ops
# from the wrong block. Clear these dicts first if that is ever needed.
_W_ROTR, _W_SHL, _W_SHR = {}, {}, {}
|
|
|
|
|
def _w_rotr(k):
    """Return the 1x1 conv weight constant for a 32-bit rotate-right by ``k``.

    The float16 weight has shape (32, 32, 1, 1), indexed [out channel, in
    channel, kH, kW]. Entry [o, (o + k) % 32] is 1 and all others are 0, so
    output bit ``o`` copies input bit ``(o + k) mod 32``.
    """
    weight = np.zeros((32, 32, 1, 1), dtype=np.float16)
    for dst, src in ((o, (o + k) % 32) for o in range(32)):
        weight[dst, src, 0, 0] = 1.0
    return mb.const(val=weight)
|
|
|
|
|
def _w_shl(k):
    """Return the 1x1 conv weight constant for a 32-bit logical left shift by ``k``.

    Output bit ``o`` copies input bit ``o - k``; rows ``o < k`` stay all-zero,
    so vacated low bits become 0. For ``k >= 32`` the whole matrix is zero.
    Shape is (32, 32, 1, 1), float16.
    """
    weight = np.zeros((32, 32, 1, 1), dtype=np.float16)
    # Only destination bits k..31 receive a source bit.
    for dst in range(max(k, 0), 32):
        weight[dst, dst - k, 0, 0] = 1.0
    return mb.const(val=weight)
|
|
|
|
|
def _w_shr(k):
    """Return the 1x1 conv weight constant for a 32-bit logical right shift by ``k``.

    Output bit ``o`` copies input bit ``o + k`` when that index is below 32;
    the top ``k`` output bits stay 0. For ``k >= 32`` the whole matrix is zero.
    Shape is (32, 32, 1, 1), float16.
    """
    weight = np.zeros((32, 32, 1, 1), dtype=np.float16)
    for dst in range(32):
        src = dst + k
        if src >= 32:
            break  # src grows with dst, so every remaining row stays zero
        weight[dst, src, 0, 0] = 1.0
    return mb.const(val=weight)
|
|
|
|
|
def rotr(x,k):
    """Rotate a bit-sliced 32-bit word right by ``k`` positions (mod 32).

    The rotation is a 1x1 convolution over the 32 bit channels with a
    permutation weight. Weights are cached per amount in ``_W_ROTR``. A
    rotation by a multiple of 32 returns ``x`` itself and adds no ops.
    """
    amount = k % 32
    if not amount:
        return x
    weight = _W_ROTR.get(amount)
    if weight is None:
        weight = _W_ROTR[amount] = _w_rotr(amount)
    return mb.conv(x=x, weight=weight, pad_type="valid", groups=1)
|
|
|
|
|
def shl(x,k):
    """Logical left shift of a bit-sliced 32-bit word by ``k`` positions.

    Bit ``o`` of the result is bit ``o - k`` of ``x``, and vacated low bits are
    0. Shifts of 32 or more give the all-zero word, because every bit is
    shifted out. Negative amounts are treated as no shift and return ``x``.
    The shift is a 1x1 convolution over the 32 bit channels; weights are
    cached per amount in ``_W_SHL``.
    """
    # Shifts of 32+ all share the all-zero weight cached under key 32.
    # Clamping them to 31 (the old behavior) would wrongly keep bit 0.
    k = 0 if k < 0 else min(k, 32)
    if k == 0:
        return x
    if k not in _W_SHL:
        _W_SHL[k] = _w_shl(k)
    return mb.conv(x=x, weight=_W_SHL[k], pad_type="valid", groups=1)
|
|
|
|
|
def shr(x,k):
    """Logical right shift of a bit-sliced 32-bit word by ``k`` positions.

    Bit ``o`` of the result is bit ``o + k`` of ``x``, and vacated high bits
    are 0. Shifts of 32 or more give the all-zero word, because every bit is
    shifted out. Negative amounts are treated as no shift and return ``x``.
    The shift is a 1x1 convolution over the 32 bit channels; weights are
    cached per amount in ``_W_SHR``.
    """
    # Shifts of 32+ all share the all-zero weight cached under key 32.
    # Clamping them to 31 (the old behavior) would wrongly keep bit 31.
    k = 0 if k < 0 else min(k, 32)
    if k == 0:
        return x
    if k not in _W_SHR:
        _W_SHR[k] = _w_shr(k)
    return mb.conv(x=x, weight=_W_SHR[k], pad_type="valid", groups=1)
|
|
|
|
|
def Sigma0(x):
    """SHA-256 upper-case Sigma0: ``ROTR^2(x) ^ ROTR^13(x) ^ ROTR^22(x)``."""
    r2, r13, r22 = (rotr(x, n) for n in (2, 13, 22))
    return xor3(r2, r13, r22)
|
|
|
|
|
def Sigma1(x):
    """SHA-256 upper-case Sigma1: ``ROTR^6(x) ^ ROTR^11(x) ^ ROTR^25(x)``."""
    r6, r11, r25 = (rotr(x, n) for n in (6, 11, 25))
    return xor3(r6, r11, r25)
|
|
|
|
|
def sigma0(x):
    """SHA-256 lower-case sigma0 (message schedule): ``ROTR^7(x) ^ ROTR^18(x) ^ SHR^3(x)``."""
    r7 = rotr(x, 7)
    r18 = rotr(x, 18)
    s3 = shr(x, 3)
    return xor3(r7, r18, s3)
|
|
|
|
|
def sigma1(x):
    """SHA-256 lower-case sigma1 (message schedule): ``ROTR^17(x) ^ ROTR^19(x) ^ SHR^10(x)``."""
    r17 = rotr(x, 17)
    r19 = rotr(x, 19)
    s10 = shr(x, 10)
    return xor3(r17, r19, s10)
|
|
|
|
|
def csa(a,b,c):
    """3:2 carry-save compressor on bit-sliced words.

    Returns ``(sum_bits, carry_bits)`` such that, at every bit position,
    ``a + b + c == sum_bits + 2 * carry_bits``. The carry word is NOT yet
    shifted; callers shift it left by one before adding it back.
    """
    sum_bits = xor3(a, b, c)
    carry_bits = maj(a, b, c)
    return sum_bits, carry_bits
|
|
|
|
|
def cpa(a,b):
    """Carry-propagate adder: return ``(a + b) mod 2**32`` for bit-sliced words.

    Carries are resolved with a Kogge-Stone parallel-prefix network over the
    bit axis: 5 levels with spans 1, 2, 4, 8 and 16 cover all 32 bits. Shifts
    drop anything carried out of bit 31, which gives the mod 2**32 wrap.
    """
    half_sum = bxor(a, b)  # per-bit sum ignoring carries; also the initial propagate signal
    gen = band(a, b)
    prop = half_sum
    for span in (1, 2, 4, 8, 16):
        # Merge each bit's (generate, propagate) pair with the group `span`
        # bits below it. Both right-hand sides read the previous level.
        gen, prop = bor(gen, band(prop, shl(gen, span))), band(prop, shl(prop, span))
    # gen[i] now means "bits 0..i carry out of bit i", so shifting it up one
    # position gives each bit's carry-in.
    return bxor(half_sum, shl(gen, 1))
|
|
|
|
|
def add2(a,b):
    """Return ``(a + b) mod 2**32`` for two bit-sliced words (a single carry-propagate add)."""
    total = cpa(a, b)
    return total
|
|
|
|
|
def add3(a,b,c):
    """Return ``(a + b + c) mod 2**32``: one carry-save stage, then one carry-propagate add."""
    partial_sum, carries = csa(a, b, c)
    # Each carry bit has the weight of the next bit up, so align it before the final add.
    return cpa(partial_sum, shl(carries, 1))
|
|
|
|
|
def add4(a,b,c,d):
    """Return ``(a + b + c + d) mod 2**32`` for four bit-sliced words.

    One carry-save stage folds ``a, b, c`` into a sum word and a carry word.
    ``add3`` then absorbs ``d`` with a second carry-save stage and the final
    carry-propagate add.
    """
    # Two 3:2 stages are enough to reduce four operands to two. No extra
    # stage with a constant-zero operand (and no per-call zero const) is needed.
    partial_sum, carries = csa(a, b, c)
    # a + b + c == partial_sum + 2 * carries (mod 2**32)
    return add3(partial_sum, shl(carries, 1), d)
|
|
|
|
|
def add5(a,b,c,d,e):
    """Return ``(a + b + c + d + e) mod 2**32`` for five bit-sliced words.

    Two carry-save stages reduce the five operands to three: the second sum
    and the two shifted carry words. ``add3`` performs the final carry-save
    stage and the carry-propagate add.
    """
    sum_abc, carry_abc = csa(a, b, c)
    sum_de, carry_de = csa(d, e, sum_abc)
    return add3(sum_de, shl(carry_abc, 1), shl(carry_de, 1))
|
|
|
|
|
# SHA-256 round constants K[0..63] (FIPS 180-4 section 4.2.2): the first 32
# bits of the fractional parts of the cube roots of the first 64 primes.
K_vals = [
    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5,
    0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
    0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc,
    0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
    0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
    0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3,
    0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5,
    0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
    0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
]
|
|
|
|
|
# SHA-256 initial hash value H(0) (FIPS 180-4 section 5.3.3): the first 32
# bits of the fractional parts of the square roots of the first 8 primes.
IV_vals = [
    0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
    0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
]
|
|
|
|
|
# Flexible leading (batch) dimension: 1 to 1024 entries per call, default 1.
# `N` is its symbol for the MIL input specs; `flexN` itself is reused for the
# ct.convert input types.
flexN = ct.RangeDim(lower_bound=1, upper_bound=1024, default=1)
N = flexN.symbol
|
|
|
|
|
def _expand_schedule(words):
    """Extend 16 bit-sliced block words to the 64-entry SHA-256 message schedule.

    Entries 16..63 follow
    ``W[t] = sigma1(W[t-2]) + W[t-7] + sigma0(W[t-15]) + W[t-16]`` (mod 2**32).
    Returns a new list; ``words`` itself is not modified.
    """
    schedule = list(words)
    for t in range(16, 64):
        schedule.append(add4(sigma1(schedule[t - 2]), schedule[t - 7],
                             sigma0(schedule[t - 15]), schedule[t - 16]))
    return schedule


def _compress(state, schedule, k_bits):
    """Run the 64 SHA-256 rounds and return the new 8-word chaining value.

    ``state`` is the incoming chaining value (words a..h). The result is the
    word-wise mod-2**32 sum of ``state`` and the working variables after the
    last round.
    """
    a, b, c, d, e, f, g, h = state
    for w_t, k_t in zip(schedule, k_bits):
        t1 = add5(h, Sigma1(e), ch(e, f, g), w_t, k_t)
        t2 = add2(Sigma0(a), maj(a, b, c))
        a, b, c, d, e, f, g, h = add2(t1, t2), a, b, c, add2(d, t1), e, f, g
    return [add2(start, end) for start, end in zip(state, (a, b, c, d, e, f, g, h))]


@mb.program(
    input_specs=[
        mb.TensorSpec(shape=(N, 32, 1, 8), dtype=types.fp16),
        mb.TensorSpec(shape=(N, 32, 1, 16), dtype=types.fp16),
    ],
    opset_version=ct.target.iOS18,
)
def prog(midstate, w_init):
    """Bit-sliced double SHA-256 graph.

    Both inputs are float16. Axis 3 indexes 32-bit words, and axis 1 holds
    each word's bits, least-significant first. Every element is expected to
    be exactly 0.0 or 1.0, because the boolean helpers rely on that.

    midstate: (N, 32, 1, 8) chaining value that the first compression starts from.
    w_init:   (N, 32, 1, 16) message words for that compression. The first
              compression's output is used directly as the finished first-pass
              digest, so these words presumably must already include SHA-256
              padding. TODO confirm against the caller.

    Returns (N, 32, 1, 8): SHA-256 of that 32-byte first-pass digest, i.e. a
    single padded block compressed from the standard IV.
    """
    k_bits = [bit_const(k) for k in K_vals]
    iv_bits = [bit_const(v) for v in IV_vals]

    # Pass 1: one compression of w_init, chained from the supplied midstate.
    first_state = mb.split(x=midstate, axis=3, num_splits=8)
    first_schedule = _expand_schedule(mb.split(x=w_init, axis=3, num_splits=16))
    digest1 = _compress(first_state, first_schedule, k_bits)

    # Pass 2: the 256-bit digest padded to one 512-bit block.
    # Layout: 8 digest words, the 0x80000000 terminator, six zero words, and
    # the message length (256 bits) as the last word.
    zero = mb.const(val=np.zeros((1, 32, 1, 1), dtype=np.float16))
    second_block = [*digest1, bit_const(0x80000000), *([zero] * 6), bit_const(256)]
    digest2 = _compress(iv_bits, _expand_schedule(second_block), k_bits)

    return mb.concat(values=digest2, axis=3)
|
|
|
|
|
# Convert the MIL program to an ML Program and write it to disk.
# This runs at import time.
# NOTE(review): `inputs=` is passed alongside a pymil Program source; confirm
# the installed coremltools honours it (e.g. the 1..1024 range on the leading
# dim) rather than ignoring it for this source type.
_input_types = [
    ct.TensorType(name="midstate", shape=(flexN, 32, 1, 8), dtype=np.float16),
    ct.TensorType(name="w_init", shape=(flexN, 32, 1, 16), dtype=np.float16),
]

mlmodel = ct.convert(
    prog,
    inputs=_input_types,
    convert_to="mlprogram",
    minimum_deployment_target=ct.target.iOS18,
    compute_units=ct.ComputeUnit.CPU_AND_NE,
    compute_precision=ct.precision.FLOAT16,
    debug=True,
)

mlmodel.save("sha256d.mlpackage")