/**
 * @typedef {number} Cell
 * @typedef {Cell[][]|Cell[][][]} Matrix
 * @typedef {number[]} Shape
 * @typedef {number[]} CellIndices
 */

/**
 * Gets the matrix's shape.
 *
 * @param {Matrix} m
 * @returns {Shape}
 */
export const shape = (m) => {
  const shapes = [];
  let dimension = m;
  while (dimension && Array.isArray(dimension)) {
    shapes.push(dimension.length);
    dimension = (dimension.length && [...dimension][0]) || null;
  }
  return shapes;
};

/**
 * Checks if matrix has a correct type.
 *
 * @param {Matrix} m
 * @throws {Error}
 */
const validateType = (m) => {
  if (
    !m
    || !Array.isArray(m)
    || !Array.isArray(m[0])
  ) {
    throw new Error('Invalid matrix format');
  }
};

/**
 * Checks if matrix is two dimensional.
 *
 * @param {Matrix} m
 * @throws {Error}
 */
const validate2D = (m) => {
  validateType(m);
  const aShape = shape(m);
  if (aShape.length !== 2) {
    throw new Error('Matrix is not of 2D shape');
  }
};

/**
 * Validates that matrices are of the same shape.
 *
 * @param {Matrix} a
 * @param {Matrix} b
 * @trows {Error}
 */
export const validateSameShape = (a, b) => {
  validateType(a);
  validateType(b);

  const aShape = shape(a);
  const bShape = shape(b);

  if (aShape.length !== bShape.length) {
    throw new Error('Matrices have different dimensions');
  }

  while (aShape.length && bShape.length) {
    if (aShape.pop() !== bShape.pop()) {
      throw new Error('Matrices have different shapes');
    }
  }
};

/**
 * Generates the matrix of specific shape with specific values.
 *
 * @param {Shape} mShape - the shape of the matrix to generate
 * @param {function({CellIndex}): Cell} fill - cell values of a generated matrix.
 * @returns {Matrix}
 */
export const generate = (mShape, fill) => {
  /**
   * Generates the matrix recursively.
   *
   * @param {Shape} recShape - the shape of the matrix to generate
   * @param {CellIndices} recIndices
   * @returns {Matrix}
   */
  const generateRecursively = (recShape, recIndices) => {
    if (recShape.length === 1) {
      return Array(recShape[0])
        .fill(null)
        .map((cellValue, cellIndex) => fill([...recIndices, cellIndex]));
    }
    const m = [];
    for (let i = 0; i < recShape[0]; i += 1) {
      m.push(generateRecursively(recShape.slice(1), [...recIndices, i]));
    }
    return m;
  };

  return generateRecursively(mShape, []);
};

/**
 * Generates the matrix of zeros of specified shape.
 *
 * @param {Shape} mShape - shape of the matrix
 * @returns {Matrix}
 */
export const zeros = (mShape) => {
  return generate(mShape, () => 0);
};

/**
 * @param {Matrix} a
 * @param {Matrix} b
 * @return Matrix
 * @throws {Error}
 */
export const dot = (a, b) => {
  // Validate inputs.
  validate2D(a);
  validate2D(b);

  // Check dimensions.
  const aShape = shape(a);
  const bShape = shape(b);
  if (aShape[1] !== bShape[0]) {
    throw new Error('Matrices have incompatible shape for multiplication');
  }

  // Perform matrix multiplication.
  const outputShape = [aShape[0], bShape[1]];
  const c = zeros(outputShape);

  for (let bCol = 0; bCol < b[0].length; bCol += 1) {
    for (let aRow = 0; aRow < a.length; aRow += 1) {
      let cellSum = 0;
      for (let aCol = 0; aCol < a[aRow].length; aCol += 1) {
        cellSum += a[aRow][aCol] * b[aCol][bCol];
      }
      c[aRow][bCol] = cellSum;
    }
  }

  return c;
};

/**
 * Transposes the matrix.
 *
 * @param {Matrix} m
 * @returns Matrix
 * @throws {Error}
 */
export const t = (m) => {
  validate2D(m);
  const mShape = shape(m);
  const transposed = zeros([mShape[1], mShape[0]]);
  for (let row = 0; row < m.length; row += 1) {
    for (let col = 0; col < m[0].length; col += 1) {
      transposed[col][row] = m[row][col];
    }
  }
  return transposed;
};

/**
 * Traverses the matrix.
 *
 * @param {Matrix} m
 * @param {function(indices: CellIndices, c: Cell)} visit
 */
export const walk = (m, visit) => {
  /**
   * Traverses the matrix recursively.
   *
   * @param {Matrix} recM
   * @param {CellIndices} cellIndices
   * @return {Matrix}
   */
  const recWalk = (recM, cellIndices) => {
    const recMShape = shape(recM);

    if (recMShape.length === 1) {
      for (let i = 0; i < recM.length; i += 1) {
        visit([...cellIndices, i], recM[i]);
      }
    }
    for (let i = 0; i < recM.length; i += 1) {
      recWalk(recM[i], [...cellIndices, i]);
    }
  };

  recWalk(m, []);
};

/**
 * Gets the matrix cell value at specific index.
 *
 * @param {Matrix} m - Matrix that contains the cell that needs to be updated
 * @param {CellIndices} cellIndices - Array of cell indices
 * @return {Cell}
 */
export const getCellAtIndex = (m, cellIndices) => {
  // We start from the row at specific index.
  let cell = m[cellIndices[0]];
  // Going deeper into the next dimensions but not to the last one to preserve
  // the pointer to the last dimension array.
  for (let dimIdx = 1; dimIdx < cellIndices.length - 1; dimIdx += 1) {
    cell = cell[cellIndices[dimIdx]];
  }
  // At this moment the cell variable points to the array at the last needed dimension.
  return cell[cellIndices[cellIndices.length - 1]];
};

/**
 * Update the matrix cell at specific index.
 *
 * @param {Matrix} m - Matrix that contains the cell that needs to be updated
 * @param {CellIndices} cellIndices - Array of cell indices
 * @param {Cell} cellValue - New cell value
 */
export const updateCellAtIndex = (m, cellIndices, cellValue) => {
  // We start from the row at specific index.
  let cell = m[cellIndices[0]];
  // Going deeper into the next dimensions but not to the last one to preserve
  // the pointer to the last dimension array.
  for (let dimIdx = 1; dimIdx < cellIndices.length - 1; dimIdx += 1) {
    cell = cell[cellIndices[dimIdx]];
  }
  // At this moment the cell variable points to the array at the last needed dimension.
  cell[cellIndices[cellIndices.length - 1]] = cellValue;
};

/**
 * Adds two matrices element-wise.
 *
 * @param {Matrix} a
 * @param {Matrix} b
 * @return {Matrix}
 */
export const add = (a, b) => {
  validateSameShape(a, b);
  const result = zeros(shape(a));

  walk(a, (cellIndices, cellValue) => {
    updateCellAtIndex(result, cellIndices, cellValue);
  });

  walk(b, (cellIndices, cellValue) => {
    const currentCellValue = getCellAtIndex(result, cellIndices);
    updateCellAtIndex(result, cellIndices, currentCellValue + cellValue);
  });

  return result;
};

/**
 * Multiplies two matrices element-wise.
 *
 * @param {Matrix} a
 * @param {Matrix} b
 * @return {Matrix}
 */
export const mul = (a, b) => {
  validateSameShape(a, b);
  const result = zeros(shape(a));

  walk(a, (cellIndices, cellValue) => {
    updateCellAtIndex(result, cellIndices, cellValue);
  });

  walk(b, (cellIndices, cellValue) => {
    const currentCellValue = getCellAtIndex(result, cellIndices);
    updateCellAtIndex(result, cellIndices, currentCellValue * cellValue);
  });

  return result;
};

/**
 * Subtract two matrices element-wise.
 *
 * @param {Matrix} a
 * @param {Matrix} b
 * @return {Matrix}
 */
export const sub = (a, b) => {
  validateSameShape(a, b);
  const result = zeros(shape(a));

  walk(a, (cellIndices, cellValue) => {
    updateCellAtIndex(result, cellIndices, cellValue);
  });

  walk(b, (cellIndices, cellValue) => {
    const currentCellValue = getCellAtIndex(result, cellIndices);
    updateCellAtIndex(result, cellIndices, currentCellValue - cellValue);
  });

  return result;
};