// Fragment of a layer-loading routine; the enclosing function header is not
// visible in this view. From `file` it reads, in order: three unsigned shape
// values, a 2-D weight tensor, a bias tensor, and an activation description.
// Every step is wrapped in KASSERT with a diagnostic message; what KASSERT
// does on failure is defined outside this view.
// NOTE(review): the leading number on each statement looks like a line number
// carried over from an extraction step - confirm against the real file.
146 KASSERT(file,
"Invalid file stream");
// Number of weight rows; a value of zero is rejected.
148 unsigned int weights_rows = 0;
149 KASSERT(ReadUnsignedInt(file, &weights_rows),
"Expected weight rows");
150 KASSERT(weights_rows > 0,
"Invalid weights # rows");
// Number of weight columns; a value of zero is rejected.
152 unsigned int weights_cols = 0;
153 KASSERT(ReadUnsignedInt(file, &weights_cols),
"Expected weight cols");
154 KASSERT(weights_cols > 0,
"Invalid weights shape");
// Length of the bias vector; a value of zero is rejected.
// NOTE(review): none of the visible lines checks biases_shape against the
// weight dimensions - confirm whether a consistency check is wanted.
156 unsigned int biases_shape = 0;
157 KASSERT(ReadUnsignedInt(file, &biases_shape),
"Expected biases shape");
158 KASSERT(biases_shape > 0,
"Invalid biases shape");
// Size the weight tensor, then read rows * cols floats into its storage.
// NOTE(review): the product is evaluated in unsigned int, so very large shape
// values taken from the file would wrap around.
// NOTE(review): the ReadFloats statement below appears truncated in this view
// (its wrapper and message line are not shown).
160 weights_.Resize(weights_rows, weights_cols);
162 ReadFloats(file, weights_.data_.data(), weights_rows * weights_cols),
// Size the bias tensor and read biases_shape floats into its storage.
// NOTE(review): the message argument of this KASSERT is not shown in this view.
165 biases_.Resize(biases_shape);
166 KASSERT(ReadFloats(file, biases_.data_.data(), biases_shape),
// The activation member loads its own configuration from the same stream.
169 KASSERT(activation_.LoadLayer(file),
"Failed to load activation");
// Fragment of a layer-loading routine; the enclosing function header is not
// visible in this view. From `file` it reads, in order: four unsigned weight
// dimensions (i, j, k, l), the bias length, the 4-D weight tensor, the bias
// tensor, and an activation description.
// Every step is wrapped in KASSERT with a diagnostic message; what KASSERT
// does on failure is defined outside this view.
// NOTE(review): the leading number on each statement looks like a line number
// carried over from an extraction step - confirm against the real file.
207 KASSERT(file,
"Invalid file stream");
// First weight dimension; zero is rejected.
209 unsigned int weights_i = 0;
210 KASSERT(ReadUnsignedInt(file, &weights_i),
"Expected weights_i");
211 KASSERT(weights_i > 0,
"Invalid weights # i");
// Second weight dimension; zero is rejected.
213 unsigned int weights_j = 0;
214 KASSERT(ReadUnsignedInt(file, &weights_j),
"Expected weights_j");
215 KASSERT(weights_j > 0,
"Invalid weights # j");
// Third weight dimension; zero is rejected.
217 unsigned int weights_k = 0;
218 KASSERT(ReadUnsignedInt(file, &weights_k),
"Expected weights_k");
219 KASSERT(weights_k > 0,
"Invalid weights # k");
// Fourth weight dimension; zero is rejected.
221 unsigned int weights_l = 0;
222 KASSERT(ReadUnsignedInt(file, &weights_l),
"Expected weights_l");
223 KASSERT(weights_l > 0,
"Invalid weights # l");
// Length of the bias vector; zero is rejected.
// NOTE(review): none of the visible lines checks biases_shape against
// weights_i - confirm whether a consistency check is wanted.
225 unsigned int biases_shape = 0;
226 KASSERT(ReadUnsignedInt(file, &biases_shape),
"Expected biases shape");
227 KASSERT(biases_shape > 0,
"Invalid biases shape");
// Size the 4-D weight tensor, then read i * j * k * l floats into its storage.
// NOTE(review): the four-way product is evaluated in unsigned int, so large
// shape values taken from the file would wrap around.
// NOTE(review): the message argument of this KASSERT is not shown in this view.
229 weights_.Resize(weights_i, weights_j, weights_k, weights_l);
230 KASSERT(ReadFloats(file, weights_.data_.data(),
231 weights_i * weights_j * weights_k * weights_l),
// Size the bias tensor and read biases_shape floats into its storage.
// NOTE(review): the message argument of this KASSERT is not shown in this view.
234 biases_.Resize(biases_shape);
235 KASSERT(ReadFloats(file, biases_.data_.data(), biases_shape),
// The activation member loads its own configuration from the same stream.
238 KASSERT(activation_.LoadLayer(file),
"Failed to load activation");
// Fragment of a forward pass that slides the 4-D weights_ kernel over the
// 3-D tensor `in`, adds biases_, and hands the result to activation_, which
// writes `out`. The enclosing function header and the closing braces of the
// loops are not visible in this view.
// NOTE(review): the leading number on each statement looks like a line number
// carried over from an extraction step - confirm against the real file.
// Both tensor pointers must be non-null.
244 KASSERT(in,
"Invalid input");
245 KASSERT(out,
"Invalid output");
// The first input dimension must equal the second kernel dimension, because
// both are indexed by the same loop variable j below.
247 KASSERT(in->
dims_[0] == weights_.dims_[1],
248 "Input 'depth' doesn't match kernel 'depth'");
// Split each kernel extent K into taps before (st_n*) and after (st_p*) the
// anchor position. With integer division and K >= 1 the two parts sum to K - 1.
250 int st_nj = (weights_.dims_[2] - 1) / 2;
251 int st_pj = (weights_.dims_[2]) / 2;
252 int st_nk = (weights_.dims_[3] - 1) / 2;
253 int st_pk = (weights_.dims_[3]) / 2;
// Accumulator tensor: one plane per weights_.dims_[0], with each of the other
// two input extents reduced by the corresponding (st_n + st_p).
// NOTE(review): the += below relies on tmp starting at zero - confirm that
// this Tensor constructor zero-fills.
// NOTE(review): nothing visible guards against the input being smaller than
// the kernel, which would make these extents non-positive - confirm upstream.
255 Tensor tmp(weights_.dims_[0], in->
dims_[1] - st_nj - st_pj,
256 in->
dims_[2] - st_nk - st_pk);
// Loop nest: i = output plane, j = input plane, (tj, tk) = anchor position in
// the input, (k, l) = offset inside the kernel.
259 for (
int i = 0; i < weights_.dims_[0]; i++) {
261 for (
int j = 0; j < weights_.dims_[1]; j++) {
// Anchors run only where the whole kernel window fits inside the input.
263 for (
int tj = st_nj; tj < in->
dims_[1] - st_pj; tj++) {
264 for (
int tk = st_nk; tk < in->
dims_[2] - st_pk; tk++) {
266 for (
int k = 0; k < weights_.dims_[2]; k++) {
267 for (
int l = 0; l < weights_.dims_[3]; l++) {
// Multiply the kernel tap by the input sample at the window origin plus
// (k, l) and accumulate into the output cell for this anchor.
// NOTE(review): the declaration that receives the input sample (`value`)
// appears to be missing from this view; only its right-hand side is shown.
268 const float weight = weights_(i, j, k, l);
270 (*in)(j, tj - st_nj + k, tk - st_nk + l);
272 tmp(i, tj - st_nj, tk - st_nk) += weight * value;
// Add biases_(i) to every position of plane i of tmp.
// NOTE(review): `i` comes from an enclosing scope that is not shown here.
280 for (
int j = 0; j < tmp.
dims_[1]; j++) {
281 for (
int k = 0; k < tmp.
dims_[2]; k++) {
282 tmp(i, j, k) += biases_(i);
// Run the activation over the accumulated tensor; it writes into `out`.
287 KASSERT(activation_.Apply(&tmp, out),
"Failed to apply activation");
// Fragment of a layer-loading routine; the enclosing function header is not
// visible in this view, and the routine may continue past the last line shown.
// Stream layout as consumed below:
//   1. twenty unsigned shape values: for each of the four parameter groups
//      i, f, c, o the W rows/cols, the U rows/cols and the b length;
//   2. the float data for those twelve tensors, in the same group order;
//   3. the inner activation, then the outer activation;
//   4. one unsigned value stored as the return_sequences_ flag.
// The groups are presumably the input, forget, cell and output gates of a
// recurrent layer - TODO confirm against the class definition.
// Every step is wrapped in KASSERT with a diagnostic message; what KASSERT
// does on failure is defined outside this view.
// NOTE(review): the leading number on each statement looks like a line number
// carried over from an extraction step - confirm against the real file.
// NOTE(review): each shape is only checked for being non-zero; no visible line
// cross-checks the W, U and b shapes against each other.
378 KASSERT(file,
"Invalid file stream");
// ---- Shapes for group i: Wi (rows, cols), Ui (rows, cols), bi length ----
380 unsigned int wi_rows = 0;
381 KASSERT(ReadUnsignedInt(file, &wi_rows),
"Expected Wi rows");
382 KASSERT(wi_rows > 0,
"Invalid Wi # rows");
384 unsigned int wi_cols = 0;
385 KASSERT(ReadUnsignedInt(file, &wi_cols),
"Expected Wi cols");
386 KASSERT(wi_cols > 0,
"Invalid Wi shape");
388 unsigned int ui_rows = 0;
389 KASSERT(ReadUnsignedInt(file, &ui_rows),
"Expected Ui rows");
390 KASSERT(ui_rows > 0,
"Invalid Ui # rows");
392 unsigned int ui_cols = 0;
393 KASSERT(ReadUnsignedInt(file, &ui_cols),
"Expected Ui cols");
394 KASSERT(ui_cols > 0,
"Invalid Ui shape");
396 unsigned int bi_shape = 0;
397 KASSERT(ReadUnsignedInt(file, &bi_shape),
"Expected bi shape");
398 KASSERT(bi_shape > 0,
"Invalid bi shape");
// ---- Shapes for group f: Wf (rows, cols), Uf (rows, cols), bf length ----
400 unsigned int wf_rows = 0;
401 KASSERT(ReadUnsignedInt(file, &wf_rows),
"Expected Wf rows");
402 KASSERT(wf_rows > 0,
"Invalid Wf # rows");
404 unsigned int wf_cols = 0;
405 KASSERT(ReadUnsignedInt(file, &wf_cols),
"Expected Wf cols");
406 KASSERT(wf_cols > 0,
"Invalid Wf shape");
408 unsigned int uf_rows = 0;
409 KASSERT(ReadUnsignedInt(file, &uf_rows),
"Expected Uf rows");
410 KASSERT(uf_rows > 0,
"Invalid Uf # rows");
412 unsigned int uf_cols = 0;
413 KASSERT(ReadUnsignedInt(file, &uf_cols),
"Expected Uf cols");
414 KASSERT(uf_cols > 0,
"Invalid Uf shape");
416 unsigned int bf_shape = 0;
417 KASSERT(ReadUnsignedInt(file, &bf_shape),
"Expected bf shape");
418 KASSERT(bf_shape > 0,
"Invalid bf shape");
// ---- Shapes for group c: Wc (rows, cols), Uc (rows, cols), bc length ----
420 unsigned int wc_rows = 0;
421 KASSERT(ReadUnsignedInt(file, &wc_rows),
"Expected Wc rows");
422 KASSERT(wc_rows > 0,
"Invalid Wc # rows");
424 unsigned int wc_cols = 0;
425 KASSERT(ReadUnsignedInt(file, &wc_cols),
"Expected Wc cols");
426 KASSERT(wc_cols > 0,
"Invalid Wc shape");
428 unsigned int uc_rows = 0;
429 KASSERT(ReadUnsignedInt(file, &uc_rows),
"Expected Uc rows");
430 KASSERT(uc_rows > 0,
"Invalid Uc # rows");
432 unsigned int uc_cols = 0;
433 KASSERT(ReadUnsignedInt(file, &uc_cols),
"Expected Uc cols");
434 KASSERT(uc_cols > 0,
"Invalid Uc shape");
436 unsigned int bc_shape = 0;
437 KASSERT(ReadUnsignedInt(file, &bc_shape),
"Expected bc shape");
438 KASSERT(bc_shape > 0,
"Invalid bc shape");
// ---- Shapes for group o: Wo (rows, cols), Uo (rows, cols), bo length ----
440 unsigned int wo_rows = 0;
441 KASSERT(ReadUnsignedInt(file, &wo_rows),
"Expected Wo rows");
442 KASSERT(wo_rows > 0,
"Invalid Wo # rows");
444 unsigned int wo_cols = 0;
445 KASSERT(ReadUnsignedInt(file, &wo_cols),
"Expected Wo cols");
446 KASSERT(wo_cols > 0,
"Invalid Wo shape");
448 unsigned int uo_rows = 0;
449 KASSERT(ReadUnsignedInt(file, &uo_rows),
"Expected Uo rows");
450 KASSERT(uo_rows > 0,
"Invalid Uo # rows");
452 unsigned int uo_cols = 0;
453 KASSERT(ReadUnsignedInt(file, &uo_cols),
"Expected Uo cols");
454 KASSERT(uo_cols > 0,
"Invalid Uo shape");
456 unsigned int bo_shape = 0;
457 KASSERT(ReadUnsignedInt(file, &bo_shape),
"Expected bo shape");
458 KASSERT(bo_shape > 0,
"Invalid bo shape");
// ---- Tensor data: each tensor is resized from its header values and then
// filled with rows * cols floats. Biases are stored as 1 x N tensors.
// NOTE(review): the element counts are unsigned int products, so very large
// header values would wrap around.
// Group i.
461 Wi_.Resize(wi_rows, wi_cols);
462 KASSERT(ReadFloats(file, Wi_.data_.data(), wi_rows * wi_cols),
463 "Expected Wi weights");
465 Ui_.Resize(ui_rows, ui_cols);
466 KASSERT(ReadFloats(file, Ui_.data_.data(), ui_rows * ui_cols),
467 "Expected Ui weights");
469 bi_.Resize(1, bi_shape);
470 KASSERT(ReadFloats(file, bi_.data_.data(), bi_shape),
"Expected bi biases");
// Group f.
473 Wf_.Resize(wf_rows, wf_cols);
474 KASSERT(ReadFloats(file, Wf_.data_.data(), wf_rows * wf_cols),
475 "Expected Wf weights");
477 Uf_.Resize(uf_rows, uf_cols);
478 KASSERT(ReadFloats(file, Uf_.data_.data(), uf_rows * uf_cols),
479 "Expected Uf weights");
481 bf_.Resize(1, bf_shape);
482 KASSERT(ReadFloats(file, bf_.data_.data(), bf_shape),
"Expected bf biases");
// Group c.
485 Wc_.Resize(wc_rows, wc_cols);
486 KASSERT(ReadFloats(file, Wc_.data_.data(), wc_rows * wc_cols),
487 "Expected Wc weights");
489 Uc_.Resize(uc_rows, uc_cols);
490 KASSERT(ReadFloats(file, Uc_.data_.data(), uc_rows * uc_cols),
491 "Expected Uc weights");
493 bc_.Resize(1, bc_shape);
494 KASSERT(ReadFloats(file, bc_.data_.data(), bc_shape),
"Expected bc biases");
// Group o.
497 Wo_.Resize(wo_rows, wo_cols);
498 KASSERT(ReadFloats(file, Wo_.data_.data(), wo_rows * wo_cols),
499 "Expected Wo weights");
501 Uo_.Resize(uo_rows, uo_cols);
502 KASSERT(ReadFloats(file, Uo_.data_.data(), uo_rows * uo_cols),
503 "Expected Uo weights");
505 bo_.Resize(1, bo_shape);
506 KASSERT(ReadFloats(file, bo_.data_.data(), bo_shape),
"Expected bo biases");
// Both activation members load their own configuration from the same stream,
// inner activation first.
508 KASSERT(innerActivation_.LoadLayer(file),
509 "Failed to load inner activation");
510 KASSERT(activation_.LoadLayer(file),
"Failed to load activation");
// The flag is read as an unsigned int; any non-zero value becomes true.
512 unsigned int return_sequences = 0;
513 KASSERT(ReadUnsignedInt(file, &return_sequences),
514 "Expected return_sequences param");
515 return_sequences_ = (bool)return_sequences;