Commit e6b6469c authored by Lena Heimberger's avatar Lena Heimberger
Browse files

SPHINCS+ java implementation

parent 2b17b2a3
# JavaSphincsPlus
This repository contains a complete Java implementation of [SPHINCS+](https://sphincs.org/) for all possible instantiations.
## Motivation
The main motivation for this project was to write a high-level implementation of the NIST PQ
Competition candidate SPHINCS+, a post-quantum secure hash-based digital signature scheme. The code is optimized for readability and straightforward usability. It aims to be faithful to the pseudocode in the [specification](https://sphincs.org/data/sphincs+-specification.pdf) to aid understanding of the signature scheme.
## Performance
Averaged over 100 iterations with random message input, the following signature and verification duration was measured.
### todo while computer at rest
## Used Libraries
Haraka is implemented without any dependencies. Please be aware that, if you are using a CPU without AES-NI hardware instructions, Haraka is susceptible to side-channel attacks.
For SHA256 and SHAKE256, a separate implementation is needed.
The implementation was tested using the JCE provided by [IAIK](https://jce.iaik.tugraz.at/sic/Products/Core_Crypto_Toolkits/JCA_JCE), but any other (correct) implementation should work.
## Features
In addition to the SPHINCS+ code, an Java JNI integration for [Haraka](https://github.com/kste/haraka) is included.
## How to use
Extensive examples can be found in the [test](https://extgit.iaik.tugraz.at/krypto/javasphincsplus/test) folder.
Generally, the Java JCE interface for
[signatures](http://javadoc.iaik.tugraz.at/jce_me/current/iaik/me/security/Signature.html) is used:
- initialization step
- ```sphincs.initSign()``` supply a private key to generate a signature
- ```sphincs.initVerify()``` supply a public key to verify a signature
- update step
- ```sphincs.update(data, begin, length)``` with the desired data
- finalization step
- ```sphincs.sign()``` to create a signature
- ```sphincs.verify()``` to verify a signature
## Related
**SPHINCS+** [website](https://sphincs.org/)
**Submission C code** on [Github](https://github.com/sphincs/sphincsplus)
**Haraka** with AESNI on [Github](https://github.com/kste/haraka)
## License
The code is licensed under the [MIT](https://choosealicense.com/licenses/mit/) license.
package at.iaik.pq.sphincs;
import at.iaik.pq.sphincs.fors.Fors;
import at.iaik.pq.sphincs.fors.ForsPair;
import at.iaik.pq.sphincs.hypertree.HypertreeSignature;
import at.iaik.pq.sphincs.keys.SphincsPrivateKey;
import at.iaik.pq.sphincs.keys.SphincsPublicKey;
import at.iaik.pq.sphincs.treeAddr.ForsTreeAdrs;
import at.iaik.pq.sphincs.treeAddr.HashTreeAdrs;
import at.iaik.pq.sphincs.utils.HashUtils;
import at.iaik.pq.sphincs.utils.SphincsParams;
import at.iaik.pq.sphincs.wots.wotshash.SphincsDigestAndIndex;
import at.iaik.pq.sphincs.wots.wotshash.WotsHash;
import iaik.asn1.structures.AlgorithmID;
import java.io.ByteArrayOutputStream;
import java.security.*;
import java.security.spec.AlgorithmParameterSpec;
import static at.iaik.pq.sphincs.utils.HashUtils.getHash;
public final class SphincsPlus extends SignatureSpi {
public static final AlgorithmID sphincsPlus = new AlgorithmID("1.3.6.1.4.1.2706.5.2", "sphincsplus", "SPHINCS+", false);
private SphincsParams sphincsParams;
private ByteArrayOutputStream msg = new ByteArrayOutputStream();
private int n;
private int a;
private WotsHash wotsHash;
private SphincsPublicKey publicKey;
private SphincsPrivateKey privateKey;
public SphincsPlus() {
}
@Override
protected void engineInitVerify(PublicKey publicKey) throws InvalidKeyException {
initVerify(publicKey);
}
public void initVerify(PublicKey publicKey) throws InvalidKeyException {
if (publicKey == null) {
throw new NullPointerException("Public key is null");
}
if (!(publicKey instanceof SphincsPublicKey))
throw new InvalidKeyException("Not a SPHINCS+ public key.");
this.publicKey = (SphincsPublicKey) publicKey;
this.sphincsParams.setPkRoot(this.publicKey.getPkRoot());
this.sphincsParams.setPkSeed(this.publicKey.getPkSeed());
wotsHash = WotsHash.algorithm(HashUtils.getHash(sphincsParams.getHashFunction()));
wotsHash.init(this.sphincsParams);
}
@Override
protected void engineInitSign(PrivateKey privateKey) throws InvalidKeyException {
initSign(privateKey);
}
public void initSign(PrivateKey privateKey) throws InvalidKeyException {
if (privateKey instanceof SphincsPrivateKey) {
this.privateKey = (SphincsPrivateKey) privateKey;
sphincsParams = new SphincsParams(this.privateKey.getSphincsParamSet(), this.privateKey);
sphincsParams.setSkSeed(this.privateKey.getSkSeed());
sphincsParams.setSkPrf(this.privateKey.getSkPrf());
sphincsParams.setPkSeed(this.privateKey.getPkSeed());
sphincsParams.setPkRoot(this.privateKey.getPkRoot());
wotsHash = WotsHash.algorithm(HashUtils.getHash(sphincsParams.getHashFunction()));
wotsHash.init(this.sphincsParams);
this.n = sphincsParams.getN();
this.a = sphincsParams.getA();
} else throw new InvalidKeyException("Not a SPHINCS+ private key");
}
@Override
protected void engineUpdate(byte b) throws SignatureException {
msg.write(b);
}
@Override
protected void engineUpdate(byte[] bytes, int i, int i1) throws SignatureException {
msg.write(bytes, i, i1);
}
public void update(byte[] bytes, int i, int i1) {
msg.write(bytes, i, i1);
}
public void reset() {
msg.reset();
}
@Override
protected byte[] engineSign() throws SignatureException {
return sign();
}
public byte[] sign() {
SphincsSignature sig = new SphincsSignature();
byte[] msgBytes = msg.toByteArray();
byte[] randomness;
byte[] digest;
//ForsTreeAdrs forsTreeAdrs = new ForsTreeAdrs();
ForsTreeAdrs forsTreeAdrs = new ForsTreeAdrs();
Fors fors = new Fors();
HypertreeSignature hypertreeSignature = new HypertreeSignature();
SphincsDigestAndIndex sphincsDigestAndIndex;
fors.setParameter(this.sphincsParams);
hypertreeSignature.setParameter(this.sphincsParams);
// compute and set randomness
byte[] opt = new byte[n]; // TODO possibly generate randomness for non-deterministic signatures
randomness = wotsHash.PRF_msg(msgBytes, opt);
sig.setRandomness(randomness);
//hash message and derive leaf index form public key, message and randomness
sphincsDigestAndIndex = wotsHash.Hmsg(msgBytes, randomness);
digest = sphincsDigestAndIndex.getDigest();
// compute FORS signature
forsTreeAdrs.setLayerAddress(0);
forsTreeAdrs.setTreeAddress(sphincsDigestAndIndex.getIdxTree());
forsTreeAdrs.setKeyPairAddress(sphincsDigestAndIndex.getIdxLeaf());
sig.setForsSignature(fors.forsSign(digest, forsTreeAdrs));
//get FORS public key
byte[] forsPk = fors.forsPkFromSig(sig.getForsSignature(), a, digest, forsTreeAdrs);
//sign the FORS public key with a hypertree signature
HashTreeAdrs hashTreeAdrs = new HashTreeAdrs();
hashTreeAdrs.setLayerAddress(forsTreeAdrs.getLayerAddress());
hashTreeAdrs.setTreeAddress(forsTreeAdrs.getTreeAddress());
hypertreeSignature.htSign(forsPk, sphincsDigestAndIndex.getIdxTree(), sphincsDigestAndIndex.getIdxLeaf(), hashTreeAdrs);
sig.setHypertreeSignature(hypertreeSignature);
return sig.serializeSignature(sphincsParams);
}
/**
* Verify the given signature
* the message must be given to the engine over engineUpdate
*
* @param bytes bytes of the signature
* @return true if verification was successful, false otherwise
* @throws SignatureException something is wrong with the signature
*/
@Override
protected boolean engineVerify(byte[] bytes) throws SignatureException {
return verify(bytes);
}
public boolean verify(byte[] bytes) {
SphincsSignature sphincsSignature = SphincsSignature.deserializeSignature(sphincsParams, bytes);
byte[] msgBytes = msg.toByteArray();
byte[] randomness = sphincsSignature.getRandomness();
ForsPair[] forsPairs = sphincsSignature.getForsSignature();
Fors fors = new Fors();
byte[] digest;
int[] idxTree;
int idxLeaf;
byte[] forsPk;
SphincsDigestAndIndex sphincsDigestAndIndex;
HypertreeSignature hypertreeSignature = sphincsSignature.getHypertreeSignature();
ForsTreeAdrs forsTreeAdrs = new ForsTreeAdrs();
fors.setParameter(this.sphincsParams);
// get message digest and indexes
sphincsDigestAndIndex = wotsHash.Hmsg(msgBytes, randomness);
digest = sphincsDigestAndIndex.getDigest();
idxTree = sphincsDigestAndIndex.getIdxTree();
idxLeaf = sphincsDigestAndIndex.getIdxLeaf();
//compute FORS public key from signature
forsTreeAdrs.setLayerAddress(0);
forsTreeAdrs.setTreeAddress(idxTree);
forsTreeAdrs.setKeyPairAddress(idxLeaf);
forsPk = fors.forsPkFromSig(forsPairs, a, digest, forsTreeAdrs);
// verify if the public key is correct
return hypertreeSignature.htVerify(forsPk, idxTree, idxLeaf, sphincsParams.getPkRoot(), sphincsParams.getHPrime(), sphincsParams.getD());
}
@Deprecated
protected void engineSetParameter(String s, Object o) throws InvalidParameterException {
throw new InvalidParameterException("Please use the AlgorithmParameters format with SphincsParams");
}
/**
* check if parameters are SphincsParams
*
* @param params parameters set
* @throws InvalidParameterException thrown if arguments are null or not an instance of SphincsParams
*/
@Override
protected void engineSetParameter(java.security.spec.AlgorithmParameterSpec params) throws java.security.InvalidAlgorithmParameterException {
if (params instanceof SphincsParams) {
this.sphincsParams = (SphincsParams) params;
setSphincsParams(this.sphincsParams);
} else throw new InvalidAlgorithmParameterException("Please use SphincsParams");
}
public void setParameter(AlgorithmParameterSpec params) {
this.sphincsParams = (SphincsParams) params;
setSphincsParams(this.sphincsParams);
}
private void setSphincsParams(SphincsParams s) {
this.n = s.getN();
this.a = s.getA();
wotsHash = WotsHash.algorithm(getHash(s.getHashFunction()));
wotsHash.init(s);
}
@Override
protected Object engineGetParameter(String s) throws InvalidParameterException {
return null;
}
}
package at.iaik.pq.sphincs;
import at.iaik.pq.sphincs.fors.Fors;
import at.iaik.pq.sphincs.fors.ForsPair;
import at.iaik.pq.sphincs.hypertree.HypertreeSignature;
import at.iaik.pq.sphincs.hypertree.XMSS;
import at.iaik.pq.sphincs.utils.SphincsParams;
import java.io.ByteArrayOutputStream;
import java.util.Arrays;
public class SphincsSignature {
private byte[] randomness;
private ForsPair[] forsSignature;
private HypertreeSignature hypertreeSignature;
public byte[] getRandomness() {
return randomness;
}
public void setRandomness(byte[] randomness) {
this.randomness = randomness;
}
public ForsPair[] getForsSignature() {
return forsSignature;
}
public void setForsSignature(ForsPair[] forsSignature) {
this.forsSignature = forsSignature;
}
public HypertreeSignature getHypertreeSignature() {
return hypertreeSignature;
}
public void setHypertreeSignature(HypertreeSignature hypertreeSignature) {
this.hypertreeSignature = hypertreeSignature;
}
public void setHypertreeSignature(XMSS[] xmss) {
this.hypertreeSignature = new HypertreeSignature();
this.hypertreeSignature.setXMSSSignatures(xmss);
}
public SphincsSignature(byte[] randomness, ForsPair[] forsSignature, HypertreeSignature hypertreeSignature) {
this.randomness = randomness;
this.forsSignature = forsSignature;
this.hypertreeSignature = hypertreeSignature;
}
public SphincsSignature() {
}
public static SphincsSignature deserializeSignature(SphincsParams sphincsParams, byte[] signature) {
SphincsSignature sphincsSignature = new SphincsSignature();
int n = sphincsParams.getN();
int offset=0;
//extract randomness
sphincsSignature.setRandomness(
Arrays.copyOfRange(signature, 0, n));
offset+=n;
//extract FORS Signature
ForsPair[] f=new ForsPair[sphincsParams.getK()];
for(int i=0; i<sphincsParams.getK(); i++){
f[i]=new ForsPair();
f[i].setSk(Arrays.copyOfRange(signature, offset, n+offset));
offset += n;
f[i].setAuth(Arrays.copyOfRange(signature, offset, sphincsParams.getA()*n+offset));
offset += sphincsParams.getA()*n;
}
sphincsSignature.setForsSignature(f);
//extract HT signature by using that we know where we left off with sizeFors
XMSS[] xmss=new XMSS[sphincsParams.getD()];
for(int i=0; i<sphincsParams.getD(); i++){
xmss[i]=new XMSS();
xmss[i].setParameter(sphincsParams);
xmss[i].setWotsSig(Arrays.copyOfRange(signature, offset, sphincsParams.getLen()*n+offset));
offset += sphincsParams.getLen()*n;
xmss[i].setAuth(Arrays.copyOfRange(signature, offset, sphincsParams.getHPrime()*n+offset));
offset += sphincsParams.getHPrime()*n;
}
sphincsSignature.setHypertreeSignature(xmss);
return sphincsSignature;
}
public byte[] serializeSignature(SphincsParams sphincsParams) {
ByteArrayOutputStream b = new ByteArrayOutputStream();
int d = sphincsParams.getD();
int h = sphincsParams.getH();
int len = sphincsParams.getLen();
int n = sphincsParams.getN();
ForsPair[] forsPairs = this.getForsSignature();
HypertreeSignature hypertreeSignature = this.getHypertreeSignature();
//int elements=input.length/sigLen;
//write randomness
b.write(this.getRandomness(), 0, n);
//write fors signatures
for (ForsPair forsPair : forsPairs) {
b.write(forsPair.getBytes(), 0, (sphincsParams.getA() + 1) * n);
}
//write hypertree signature, xmss for xmss
int sigLen = ((h / d) + len) * n;
for (int i = 0; i < d; i++) {
b.write(hypertreeSignature.getXMSSSignature(i).getEncoded(), 0, sigLen);
}
return b.toByteArray();
}
}
# ================================ #
# Makefile for haraka #
# Lena Heimberger, Feb 2020 #
# #
# Based on the original Haraka #
# Makefile by Stefan Kölbl #
# ================================ #
# which compiler?
CC=gcc
# output directories
CLASS_PATH=../bin
vpath %.class $(CLASS_PATH)
# optimization level (-Ox)
OPT=3
# debug enabled?
#DEB=-g
DEB=
LIB=-fPIC -shared -lc
# ======================================= #
# BE CAREFUL WHEN EDITING BELOW THIS LINE #
# ======================================= #
CF=-O$(OPT) -fomit-frame-pointer -funroll-all-loops -Wno-int-conversion $(DEB)
JNA_HEADERS_64=/usr/lib/jvm/java-13-openjdk-amd64/include
JNA_OS_HEADERS_64=/usr/lib/jvm/java-13-openjdk-amd64/include/linux
HARAKA_SRC=at_iaik_pq_sphincs_utils_HarakaUtils_Haraka.c
GCC_F=-march=nocona
AES_FLAGS=-maes -mssse3 -msse4.1
M32=-m32 -w $(GCC_F)
M64=-m64 $(GCC_F)
VERSION_AES=AES
.PHONY: all haraka-aes clean clean-pack clean-all pack-cl pack
all: clean haraka-aes
haraka-aes:
$(CC) $(LIB) $(AES_FLAGS) $(M64) $(CF) -I$(JNA_HEADERS_64) -I$(JNA_OS_HEADERS_64) $(HARAKA_SRC) -o $(CLASS_PATH)/haraka.so
clean:
@rm -f $(CLASS_PATH)/*
clean-pack:
@rm -f haraka.tar.bz2
clean-all: clean clean-pack
pack-cl: clean pack
pack: clean-pack
tar --exclude-vcs -cjf ../haraka.tar.bz2 ../`pwd | sed 's,^\(.*/\)\?\([^/]*\),\2,'`
mv ../haraka.tar.bz2 ./
/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class at_iaik_pq_sphincs_utils_HarakaUtils_Haraka */
#ifndef _Included_at_iaik_pq_sphincs_utils_HarakaUtils_Haraka
#define _Included_at_iaik_pq_sphincs_utils_HarakaUtils_Haraka
#ifdef __cplusplus
extern "C" {
#endif
#undef at_iaik_pq_sphincs_utils_HarakaUtils_Haraka_ROUNDS
#define at_iaik_pq_sphincs_utils_HarakaUtils_Haraka_ROUNDS 5L
#undef at_iaik_pq_sphincs_utils_HarakaUtils_Haraka_AES_ROUNDS
#define at_iaik_pq_sphincs_utils_HarakaUtils_Haraka_AES_ROUNDS 2L
/*
* Class: at_iaik_pq_sphincs_utils_HarakaUtils_Haraka
* Method: haraka256
* Signature: ([B[B)V
*/
JNIEXPORT void JNICALL Java_at_iaik_pq_sphincs_utils_HarakaUtils_Haraka_haraka256
(JNIEnv *, jobject, jbyteArray, jbyteArray);
/*
* Class: at_iaik_pq_sphincs_utils_HarakaUtils_Haraka
* Method: secret_haraka256
* Signature: ([B[B)V
*/
JNIEXPORT void JNICALL Java_at_iaik_pq_sphincs_utils_HarakaUtils_Haraka_secret_1haraka256
(JNIEnv *, jobject, jbyteArray, jbyteArray);
/*
* Class: at_iaik_pq_sphincs_utils_HarakaUtils_Haraka
* Method: haraka512
* Signature: ([B[B)V
*/
JNIEXPORT void JNICALL Java_at_iaik_pq_sphincs_utils_HarakaUtils_Haraka_haraka512
(JNIEnv *, jobject, jbyteArray, jbyteArray);
/*
* Class: at_iaik_pq_sphincs_utils_HarakaUtils_Haraka
* Method: haraka512perm
* Signature: ([B[B)V
*/
JNIEXPORT void JNICALL Java_at_iaik_pq_sphincs_utils_HarakaUtils_Haraka_haraka512perm
(JNIEnv *, jobject, jbyteArray, jbyteArray);
/*
* Class: at_iaik_pq_sphincs_utils_HarakaUtils_Haraka
* Method: load_constants
* Signature: ()V
*/
JNIEXPORT void JNICALL Java_at_iaik_pq_sphincs_utils_HarakaUtils_Haraka_load_1constants
(JNIEnv *, jobject);
/*
* Class: at_iaik_pq_sphincs_utils_HarakaUtils_Haraka
* Method: set_constants
* Signature: ([I)V
*/
JNIEXPORT void JNICALL Java_at_iaik_pq_sphincs_utils_HarakaUtils_Haraka_set_1constants
(JNIEnv *, jobject, jintArray);
/*
* Class: at_iaik_pq_sphincs_utils_HarakaUtils_Haraka
* Method: set_secret_constants
* Signature: ([I)V
*/
JNIEXPORT void JNICALL Java_at_iaik_pq_sphincs_utils_HarakaUtils_Haraka_set_1secret_1constants
(JNIEnv *, jobject, jintArray);
/*
* Class: at_iaik_pq_sphincs_utils_HarakaUtils_Haraka
* Method: check_for_native_instructions
* Signature: ()I
*/
JNIEXPORT jint JNICALL Java_at_iaik_pq_sphincs_utils_HarakaUtils_Haraka_check_1for_1native_1instructions
(JNIEnv *, jclass);
#ifdef __cplusplus
}
#endif
#endif
/*
Optimized Implementations for Haraka256 and Haraka512
*/
#ifndef HARAKA_H_
#define HARAKA_H_
#include "immintrin.h"
#define NUMROUNDS 5
#define u64 unsigned long
#define u128 __m128i
u128 rc[40];
u128 sseed_rc[40];
#define LOAD(src) _mm_load_si128((u128 *)(src))
#define STORE(dest,src) _mm_storeu_si128((u128 *)(dest),src)
#define AES2(s0, s1, rci) \
s0 = _mm_aesenc_si128(s0, rc[rci]); \
s1 = _mm_aesenc_si128(s1, rc[rci + 1]); \
s0 = _mm_aesenc_si128(s0, rc[rci + 2]); \
s1 = _mm_aesenc_si128(s1, rc[rci + 3]);
#define AES2SEC(s0, s1, rci) \
s0 = _mm_aesenc_si128(s0, sseed_rc[rci]); \
s1 = _mm_aesenc_si128(s1, sseed_rc[rci + 1]); \
s0 = _mm_aesenc_si128(s0, sseed_rc[rci + 2]); \
s1 = _mm_aesenc_si128(s1, sseed_rc[rci + 3]);
#define AES2_4x(s0, s1, s2, s3, rci) \
AES2(s0[0], s0[1], rci); \
AES2(s1[0], s1[1], rci); \
AES2(s2[0], s2[1], rci); \
AES2(s3[0], s3[1], rci);
#define AES2_8x(s0, s1, s2, s3, s4, s5, s6, s7, rci) \
AES2_4x(s0, s1, s2, s3, rci); \
AES2_4x(s4, s5, s6, s7, rci);
#define AES4(s0, s1, s2, s3, rci) \
s0 = _mm_aesenc_si128(s0, rc[rci]); \
s1 = _mm_aesenc_si128(s1, rc[rci + 1]); \
s2 = _mm_aesenc_si128(s2, rc[rci + 2]); \
s3 = _mm_aesenc_si128(s3, rc[rci + 3]); \
s0 = _mm_aesenc_si128(s0, rc[rci + 4]); \
s1 = _mm_aesenc_si128(s1, rc[rci + 5]); \
s2 = _mm_aesenc_si128(s2, rc[rci + 6]); \
s3 = _mm_aesenc_si128(s3, rc[rci + 7]); \
#define AES4_4x(s0, s1, s2, s3, rci) \
AES4(s0[0], s0[1], s0[2], s0[3], rci); \
AES4(s1[0], s1[1], s1[2], s1[3], rci); \
AES4(s2[0], s2[1], s2[2], s2[3], rci); \
AES4(s3[0], s3[1], s3[2], s3[3], rci);
#define AES4_8x(s0, s1, s2, s3, s4, s5, s6, s7, rci) \