Merge branch 'main' of feistymeow.org:feisty_meow
[feisty_meow.git] / crypto / rsa_crypto.cpp
1 /*
2 *  Name   : RSA public key encryption
3 *  Author : Chris Koeritz
4 *  Purpose:
5 *    Supports public (and private) key encryption and decryption using the
6 *    OpenSSL package's support for RSA encryption.
7 ****
8 * Copyright (c) 2005-$now By Author.  This program is free software; you can  *
9 * redistribute it and/or modify it under the terms of the GNU General Public  *
10 * License as published by the Free Software Foundation; either version 2 of   *
11 * the License or (at your option) any later version.  This is online at:      *
12 *     http://www.fsf.org/copyleft/gpl.html                                    *
13 * Please send any updates to: fred@gruntose.com                               *
14 */
15
16 //note: rsa crypto provides a nice printing method...  RSA_print_fp(stdout, private_key, 0);
17
18 // notes from openssl docs: length to be encrypted in a chunk must be less than
19 // RSA_size(rsa) - 11 for the PKCS #1 v1.5 based padding modes, less than
20 // RSA_size(rsa) - 41 for RSA_PKCS1_OAEP_PADDING and exactly RSA_size(rsa)
21 // for RSA_NO_PADDING.
22
23 #include "rsa_crypto.h"
24 #include "ssl_init.h"
25
26 #include <basis/functions.h>
27 #include <loggers/critical_events.h>
28 #include <loggers/program_wide_logger.h>
29 #include <mathematics/chaos.h>
30 #include <structures/object_packers.h>
31 #include <structures/static_memory_gremlin.h>
32
33 #include <openssl/bn.h>
34 #include <openssl/err.h>
35 #include <openssl/rsa.h>
36
37 using namespace basis;
38 using namespace loggers;
39 using namespace mathematics;
40 using namespace structures;
41
42 namespace crypto {
43
44 //#define DEBUG_RSA_CRYPTO
45   // uncomment for noisier version.
46
47 #ifdef DEBUG_RSA_CRYPTO
48   #undef LOG
49   #define LOG(s) CLASS_EMERGENCY_LOG(program_wide_logger::get(), s)
50 #else
51   #undef LOG
52   #define LOG(s)
53 #endif
54
55 SAFE_STATIC(mutex, __single_stepper, )
56   // protects unsafe areas of rsa crypto from access by multiple threads at once.
57
58 rsa_crypto::rsa_crypto(int key_size)
59 : _key(NULL_POINTER)
60 {
61   FUNCDEF("ctor(int)");
62   LOG("prior to generating key");
63   _key = generate_key(key_size);  // generate_key initializes ssl for us.
64   LOG("after generating key");
65 }
66
67 rsa_crypto::rsa_crypto(const byte_array &key)
68 : _key(NULL_POINTER)
69 {
70   FUNCDEF("ctor(byte_array)");
71   static_ssl_initializer();
72   byte_array key_copy = key;
73   LOG("prior to set key");
74   set_key(key_copy);
75   LOG("after set key");
76 }
77
78 rsa_crypto::rsa_crypto(RSA *key)
79 : _key(NULL_POINTER)
80 {
81   FUNCDEF("ctor(RSA)");
82   static_ssl_initializer();
83   LOG("prior to set key");
84   set_key(key);
85   LOG("after set key");
86 }
87
88 rsa_crypto::rsa_crypto(const rsa_crypto &to_copy)
89 : root_object(),
90   _key(NULL_POINTER)
91 {
92   FUNCDEF("copy ctor");
93   static_ssl_initializer();
94   LOG("prior to set key");
95   set_key(to_copy._key);
96   LOG("after set key");
97 }
98
99 rsa_crypto::~rsa_crypto()
100 {
101   FUNCDEF("destructor");
102   LOG("prior to rsa free");
103   auto_synchronizer mutt(__single_stepper());
104   RSA_free(_key);
105   LOG("after rsa free");
106 }
107
108 const rsa_crypto &rsa_crypto::operator = (const rsa_crypto &to_copy)
109 {
110   if (this == &to_copy) return *this;
111   set_key(to_copy._key);
112   return *this;
113 }
114
115 RSA *rsa_crypto::generate_key(int key_size)
116 {
117   FUNCDEF("generate_key");
118   if (key_size < 4) key_size = 4;  // laughable lower default.
119   static_ssl_initializer();
120   LOG("into generate key");
121   auto_synchronizer mutt(__single_stepper());
122   RSA *to_return = RSA_new();
123   BIGNUM *e = BN_new();
124   BN_set_word(e, 65537);
125 //hmmm: only one value of e?
126   int ret = RSA_generate_key_ex(to_return, key_size, e, NULL_POINTER);
127   if (!ret) {
128     continuable_error(static_class_name(), func,
129         a_sprintf("failed to generate a key of %d bits: error is %ld.", key_size, ERR_get_error()));
130     BN_free(e);
131     RSA_free(to_return);
132     return NULL;
133   }
134   LOG("after key generated");
135   BN_free(e);
136   return to_return;
137 }
138
139 bool rsa_crypto::check_key(RSA *key)
140 {
141   auto_synchronizer mutt(__single_stepper());
142   return RSA_check_key(key) == 1;
143 }
144
145 bool rsa_crypto::set_key(byte_array &key)
146 {
147   FUNCDEF("set_key [byte_array]");
148   if (!key.length()) return false;
149   auto_synchronizer mutt(__single_stepper());
150   if (_key) RSA_free(_key);
151   _key = RSA_new();
152   abyte type;
153   if (!structures::detach(key, type)) return false;
154   if ( (type != 'r') && (type != 'u') ) return false;
155   // get the public key bits first.
156   byte_array n;
157   if (!structures::detach(key, n)) return false;
158   BIGNUM *the_n = BN_bin2bn(n.access(), n.length(), NULL_POINTER);
159   if (!the_n) return false;
160   byte_array e;
161   if (!structures::detach(key, e)) return false;
162   BIGNUM *the_e = BN_bin2bn(e.access(), e.length(), NULL_POINTER);
163   if (!the_e) return false;
164
165   if (type == 'u') {
166      // done with public key.
167 #ifdef NEWER_OPENSSL
168      RSA_set0_key(_key, the_n, the_e, NULL_POINTER);
169 #else
170      _key->n = the_n; _key->e = the_e;
171 #endif
172      return true;
173   }
174
175   // the rest is for a private key.
176   byte_array d;
177   if (!structures::detach(key, d)) return false;
178   BIGNUM *the_d = BN_bin2bn(d.access(), d.length(), NULL_POINTER);
179   if (!the_d) return false;
180
181   byte_array p;
182   if (!structures::detach(key, p)) return false;
183   BIGNUM *the_p = BN_bin2bn(p.access(), p.length(), NULL_POINTER);
184   if (!the_p) return false;
185   byte_array q;
186   if (!structures::detach(key, q)) return false;
187   BIGNUM *the_q = BN_bin2bn(q.access(), q.length(), NULL_POINTER);
188   if (!the_q) return false;
189   byte_array dmp1;
190   if (!structures::detach(key, dmp1)) return false;
191   BIGNUM *the_dmp1 = BN_bin2bn(dmp1.access(), dmp1.length(), NULL_POINTER);
192   if (!the_dmp1) return false;
193   byte_array dmq1;
194   if (!structures::detach(key, dmq1)) return false;
195   BIGNUM *the_dmq1 = BN_bin2bn(dmq1.access(), dmq1.length(), NULL_POINTER);
196   if (!the_dmq1) return false;
197   byte_array iqmp;
198   if (!structures::detach(key, iqmp)) return false;
199   BIGNUM *the_iqmp = BN_bin2bn(iqmp.access(), iqmp.length(), NULL_POINTER);
200   if (!the_iqmp) return false;
201
202   // we can set the n, e and d now.
203 #ifdef NEWER_OPENSSL
204   int ret = RSA_set0_key(_key, the_n, the_e, the_d);
205   if (ret != 1) return false;
206   ret = RSA_set0_factors(_key, the_p, the_q);
207   if (ret != 1) return false;
208   ret = RSA_set0_crt_params(_key, the_dmp1, the_dmq1, the_iqmp);
209   if (ret != 1) return false;
210 #else
211   _key->n = the_n; _key->e = the_e; _key->d = the_d;
212   _key->p = the_p; _key->q = the_q;
213   _key->dmp1 = the_dmp1; _key->dmq1 = the_dmq1; _key->iqmp = the_iqmp;
214 #endif
215
216   int check = RSA_check_key(_key);
217   if (check != 1) {
218     continuable_error(static_class_name(), func, "failed to check the private "
219         "portion of the key!");
220     return false;
221   }
222
223   return true;
224 }
225
226 bool rsa_crypto::set_key(RSA *key)
227 {
228   FUNCDEF("set_key [RSA]");
229   if (!key) return NULL_POINTER;
230   // test the incoming key.
231   auto_synchronizer mutt(__single_stepper());
232   int check = RSA_check_key(key);
233   if (check != 1) return false;
234   // clean out the old key.
235   if (_key) RSA_free(_key);
236   _key = RSAPrivateKey_dup(key);
237   if (!_key) {
238     continuable_error(static_class_name(), func, "failed to create a "
239         "duplicate of the key!");
240     return false;
241   }
242   return true;
243 }
244
245 bool rsa_crypto::public_key(byte_array &pubkey) const
246 {
247   FUNCDEF("public_key");
248   if (!_key) return false;
249   structures::attach(pubkey, abyte('u'));  // signal a public key.
250   // convert the two public portions into binary.
251   BIGNUM **the_n = new BIGNUM *, **the_e = new BIGNUM *, **the_d = new BIGNUM *;
252 #ifdef NEWER_OPENSSL
253   RSA_get0_key(_key, (const BIGNUM **)the_n, (const BIGNUM **)the_e, (const BIGNUM **)the_d);
254 #else
255   *the_n = _key->n; *the_e = _key->e; *the_d = _key->d;
256 #endif
257   byte_array n(BN_num_bytes(*the_n));
258   int ret = BN_bn2bin(*the_n, n.access());
259   byte_array e(BN_num_bytes(*the_e));
260   ret = BN_bn2bin(*the_e, e.access());
261   // pack those two chunks.
262   structures::attach(pubkey, n);
263   structures::attach(pubkey, e);
264   WHACK(the_n); WHACK(the_e); WHACK(the_d);
265
266   return true;
267 }
268
269 bool rsa_crypto::private_key(byte_array &privkey) const
270 {
271   FUNCDEF("private_key");
272   if (!_key) return false;
273   int posn = privkey.length();
274   bool worked = public_key(privkey);  // get the public pieces first.
275   if (!worked) return false;
276   privkey[posn] = abyte('r');  // switch public key flag to private.
277   // convert the multiple private portions into binary.
278   //const BIGNUM **the_n = NULL_POINTER, **the_e = NULL_POINTER, **the_d = NULL_POINTER;
279   BIGNUM **the_n = new BIGNUM *, **the_e = new BIGNUM *, **the_d = new BIGNUM *;
280   BIGNUM **the_p = new BIGNUM *, **the_q = new BIGNUM *;
281   BIGNUM **the_dmp1 = new BIGNUM *, **the_dmq1 = new BIGNUM *, **the_iqmp = new BIGNUM *;
282 #ifdef NEWER_OPENSSL
283   RSA_get0_key(_key, (const BIGNUM **)the_n, (const BIGNUM **)the_e, (const BIGNUM **)the_d);
284   RSA_get0_factors(_key, (const BIGNUM **)the_p, (const BIGNUM **)the_q);
285   RSA_get0_crt_params(_key, (const BIGNUM **)the_dmp1, (const BIGNUM **)the_dmq1, (const BIGNUM **)the_iqmp);
286 #else
287   *the_n = _key->n; *the_e = _key->e; *the_d = _key->d;
288   *the_p = _key->p; *the_q = _key->q;
289   *the_dmp1 = _key->dmp1; *the_dmq1 = _key->dmq1; *the_iqmp = _key->iqmp;
290 #endif
291   byte_array d(BN_num_bytes(*the_d));
292   int ret = BN_bn2bin(*the_d, d.access());
293   byte_array p(BN_num_bytes(*the_p));
294   ret = BN_bn2bin(*the_p, p.access());
295   byte_array q(BN_num_bytes(*the_q));
296   ret = BN_bn2bin(*the_q, q.access());
297   byte_array dmp1(BN_num_bytes(*the_dmp1));
298   ret = BN_bn2bin(*the_dmp1, dmp1.access());
299   byte_array dmq1(BN_num_bytes(*the_dmq1));
300   ret = BN_bn2bin(*the_dmq1, dmq1.access());
301   byte_array iqmp(BN_num_bytes(*the_iqmp));
302   ret = BN_bn2bin(*the_iqmp, iqmp.access());
303   // pack all those in now.
304   structures::attach(privkey, d);
305   structures::attach(privkey, p);
306   structures::attach(privkey, q);
307   structures::attach(privkey, dmp1);
308   structures::attach(privkey, dmq1);
309   structures::attach(privkey, iqmp);
310   return true;
311 }
312
313 bool rsa_crypto::public_encrypt(const byte_array &source,
314     byte_array &target) const
315 {
316   FUNCDEF("public_encrypt");
317   target.reset();
318   if (!source.length()) return false;
319
320   auto_synchronizer mutt(__single_stepper());
321   const int max_chunk = RSA_size(_key) - 12;
322
323   byte_array encoded(RSA_size(_key));
324   for (int i = 0; i < source.length(); i += max_chunk) {
325     int edge = i + max_chunk - 1;
326     if (edge > source.last())
327       edge = source.last();
328     int next_chunk = edge - i + 1;
329     RSA_public_encrypt(next_chunk, &source[i],
330         encoded.access(), _key, RSA_PKCS1_PADDING);
331     target += encoded;
332   }
333   return true;
334 }
335
336 bool rsa_crypto::private_decrypt(const byte_array &source,
337     byte_array &target) const
338 {
339   FUNCDEF("private_decrypt");
340   target.reset();
341   if (!source.length()) return false;
342
343   auto_synchronizer mutt(__single_stepper());
344   const int max_chunk = RSA_size(_key);
345
346   byte_array decoded(max_chunk);
347   for (int i = 0; i < source.length(); i += max_chunk) {
348     int edge = i + max_chunk - 1;
349     if (edge > source.last())
350       edge = source.last();
351     int next_chunk = edge - i + 1;
352     int dec_size = RSA_private_decrypt(next_chunk, &source[i],
353         decoded.access(), _key, RSA_PKCS1_PADDING);
354     if (dec_size < 0) return false;  // that didn't work.
355     decoded.zap(dec_size, decoded.last());
356     target += decoded;
357     decoded.reset(max_chunk);
358   }
359   return true;
360 }
361
362 bool rsa_crypto::private_encrypt(const byte_array &source,
363     byte_array &target) const
364 {
365   FUNCDEF("private_encrypt");
366   target.reset();
367   if (!source.length()) return false;
368
369   auto_synchronizer mutt(__single_stepper());
370   const int max_chunk = RSA_size(_key) - 12;
371
372   byte_array encoded(RSA_size(_key));
373   for (int i = 0; i < source.length(); i += max_chunk) {
374     int edge = i + max_chunk - 1;
375     if (edge > source.last())
376       edge = source.last();
377     int next_chunk = edge - i + 1;
378     RSA_private_encrypt(next_chunk, &source[i],
379         encoded.access(), _key, RSA_PKCS1_PADDING);
380     target += encoded;
381   }
382   return true;
383 }
384
385 bool rsa_crypto::public_decrypt(const byte_array &source,
386     byte_array &target) const
387 {
388   FUNCDEF("public_decrypt");
389   target.reset();
390   if (!source.length()) return false;
391
392   auto_synchronizer mutt(__single_stepper());
393   const int max_chunk = RSA_size(_key);
394
395   byte_array decoded(max_chunk);
396   for (int i = 0; i < source.length(); i += max_chunk) {
397     int edge = i + max_chunk - 1;
398     if (edge > source.last())
399       edge = source.last();
400     int next_chunk = edge - i + 1;
401     int dec_size = RSA_public_decrypt(next_chunk, &source[i],
402         decoded.access(), _key, RSA_PKCS1_PADDING);
403     if (dec_size < 0) return false;  // that didn't work.
404     decoded.zap(dec_size, decoded.last());
405     target += decoded;
406     decoded.reset(max_chunk);
407   }
408   return true;
409 }
410
411 } //namespace.
412