import moment from "moment";

import {
  computeRelativeDate,
  tryParseRelativeDateOffsetValue,
} from "../datetime";
import { DATE_FORMAT_NO_TIME } from "../datetimeType";
import { CompoundColumnPredicateOp } from "../display-table";
import { FilledDynamicValueTableColumnType } from "../DynamicValue";
import { assertNever } from "../errors";
import { FilterCellColumnId } from "../idTypeBrands";
import {
  DialectUnsupportedError,
  PushdownSqlDialect,
  quoteSqlIdentifier,
} from "../sql";

import {
  Filter,
  FilterCellBinaryColumnPredicate,
  FilterCellListBinaryColumnPredicate,
  FilterCellUnaryColumnPredicate,
  FilterGroup,
  FilterType,
} from "./filterTypes";

/**
 * String-keyed dictionary of filled dynamic-value table column types.
 * Presumably keyed by column name — confirm against callers.
 */
export type ColumnTypeMap = {
  [key: string]: FilledDynamicValueTableColumnType;
};

/**
 * Builds the full `select` statement for a filter cell.
 *
 * @returns An empty string when either `filterType` or `dataframeName` is
 * missing. Otherwise a `select *` over `dataframeName`, with a `where` clause
 * when the filter group yields a predicate. The predicate is negated with
 * `not` when `filterType` is `FilterType.REMOVE`.
 */
export function getFilterCellSql({
  castStringColumns,
  dataframeName,
  dialect = "duckdb",
  filterType,
  filters,
  isParameterized,
}: {
  filters: FilterGroup | null;
  filterType: FilterType | null;
  dataframeName: string | null;
  dialect?: PushdownSqlDialect;
  isParameterized?: boolean;
  castStringColumns?: boolean;
}): string {
  // Nothing to query without both a filter type and a source dataframe.
  if (filterType == null || dataframeName == null) {
    return "";
  }

  const predicate = getSqlFromFilterGroup({
    group: filters,
    dialect,
    formatArg: jinjaArg,
    isParameterized,
    castStringColumns,
  });

  // No usable filters: select everything.
  if (predicate == null) {
    return `select * from ${dataframeName}`;
  }

  const negation = filterType === FilterType.REMOVE ? "not " : "";
  return `select *\nfrom ${dataframeName}\nwhere ${negation}${predicate}`;
}

/**
 * Renders a (possibly nested) filter group as a parenthesized SQL predicate.
 *
 * Only entries accepted by `filterFullyConfigured` are rendered. Each entry is
 * turned into a SQL fragment (recursively for nested groups) and the fragments
 * are joined with the group's operation.
 *
 * Fragments that come back empty (`null`, `undefined` or `""`, e.g. the
 * `ALWAYS` predicate or a `DATE_BETWEEN` whose bounds cannot be resolved) are
 * skipped. Joining them would produce invalid SQL such as `()` or a dangling
 * operator like `( AND x = 1)`.
 *
 * @param group The filter group to render, or null.
 * @param dialect Target SQL dialect.
 * @param formatArg Converts a filter argument into its SQL/Jinja reference.
 * @param isParameterized Forwarded to the per-filter renderer.
 * @param castStringColumns Forwarded to the per-filter renderer; defaults to false.
 * @returns The predicate wrapped in parentheses, or null when the group is
 * null or yields no non-empty fragment.
 */
export function getSqlFromFilterGroup({
  castStringColumns = false,
  dialect,
  formatArg,
  group,
  isParameterized,
}: {
  group: FilterGroup | null;
  dialect: PushdownSqlDialect;
  formatArg: (arg: FilterArg) => string;
  isParameterized?: boolean;
  castStringColumns?: boolean;
}): string | null {
  if (group == null) {
    return null;
  }

  const fragments: string[] = [];
  for (const filterOrGroup of group.filters.filter(filterFullyConfigured)) {
    let fragment: string | null | undefined;
    if (FilterGroup.guard(filterOrGroup)) {
      fragment = getSqlFromFilterGroup({
        group: filterOrGroup,
        dialect,
        formatArg,
        isParameterized,
        castStringColumns,
      });
    } else if (Filter.guard(filterOrGroup)) {
      const columnType = filterOrGroup.columnType;
      if (columnType != null) {
        fragment = getSqlFromFilter({
          filter: filterOrGroup,
          dialect,
          columnType,
          formatArg,
          isParameterized,
          castStringColumns,
        });
      }
    }

    // Skip no-op fragments so the join never emits an empty operand.
    if (fragment != null && fragment !== "") {
      fragments.push(fragment);
    }
  }

  if (fragments.length === 0) {
    return null;
  }

  return `(${fragments.join(`\n       ${group.operation} `)})`;
}

/** Names of the date comparison operations that map to a single SQL comparison operator. */
type DateOperator =
  | "DATE_BEFORE"
  | "DATE_AFTER"
  | "DATE_EQUAL"
  | "DATE_EQUAL_OR_BEFORE"
  | "DATE_EQUAL_OR_AFTER"
  | "DATE_NOT_EQUAL";

/** SQL comparison operator emitted for each date comparison operation. */
export const DATE_OPERATOR_MAP: Record<DateOperator, string> = {
  DATE_BEFORE: "<",
  DATE_EQUAL_OR_BEFORE: "<=",
  DATE_AFTER: ">",
  DATE_EQUAL_OR_AFTER: ">=",
  DATE_EQUAL: "=",
  DATE_NOT_EQUAL: "!=",
};

/**
 * Renders a single filter as a SQL predicate fragment for the given dialect.
 *
 * We wrap filter args with Jinja syntax so that our backend will extract them
 * and send them as prepared statement parameters. The filter configuration
 * can be modified by app users so we must be careful here.
 *
 * Note: There is also SQL injection potential from `filter.column` which app
 * users can also control. We cannot use prepared statements for these column
 * references but ensure safety by separately injecting code that runs before
 * this generated query that checks that all referenced column names are
 * actually columns of the referenced dataframe.
 *
 * The operation is dispatched on its predicate kind (unary, binary, or list).
 *
 * @param filter The filter to render.
 * @param columnType Column type used for casting and argument formatting.
 * @param dialect Target SQL dialect.
 * @param formatArg Converts a filter argument into its SQL/Jinja reference.
 * @param isParameterized When truthy, `CONTAINS`/`NOT_CONTAINS` build their
 * pattern from the formatted argument instead of the raw value.
 * @param castStringColumns When true, columns typed "STRING" are wrapped in a
 * dialect-specific string cast before comparison.
 * @returns The SQL fragment. This is an empty string for `ALWAYS` and for a
 * `DATE_BETWEEN` whose bounds are missing or cannot be resolved.
 */
function getSqlFromFilter({
  castStringColumns,
  columnType,
  dialect,
  filter,
  formatArg,
  isParameterized,
}: {
  filter: Filter;
  columnType: FilledDynamicValueTableColumnType;
  dialect: PushdownSqlDialect;
  formatArg: (arg: FilterArg) => string;
  isParameterized?: boolean;
  castStringColumns: boolean;
}): string | undefined {
  let quotedColumn = quoteSqlIdentifier(filter.column, dialect);
  if (castStringColumns && columnType === "STRING") {
    quotedColumn = tryStringCast({ dialect, expr: quotedColumn });
  }

  if (FilterCellUnaryColumnPredicate.guard(filter.operation)) {
    // Unary predicates take no argument.
    switch (filter.operation.op) {
      case "IS_TRUE":
        // The SQL Server dialects compare against 1/0 instead of true/false.
        return `${quotedColumn} = ${
          dialect === "sqlserver" || dialect === "cloudsql__sqlserver"
            ? "1"
            : "true"
        }`;
      case "IS_FALSE":
        return `${quotedColumn} = ${
          dialect === "sqlserver" || dialect === "cloudsql__sqlserver"
            ? "0"
            : "false"
        }`;
      case "ALWAYS":
        // No SQL condition is emitted for ALWAYS.
        return "";
      case "IS_NULL":
        return `${quotedColumn} is null`;
      case "NOT_NULL":
        return `${quotedColumn} is not null`;
      case "IS_EMPTY":
        return `${quotedColumn} = ''`;
      case "NOT_EMPTY":
        return `${quotedColumn} != ''`;
      default:
        assertNever(filter.operation, filter.operation);
    }
  } else if (FilterCellBinaryColumnPredicate.guard(filter.operation)) {
    // Binary predicates compare the column against one formatted argument.
    const paramRef = formatArg({
      columnType,
      arg: filter.operation.arg,
      allowInterpolation: filter.operation.allowInterpolation,
      dialect,
    });
    switch (filter.operation.op) {
      case "GT":
        return `${quotedColumn} > ${paramRef}`;
      case "GTE":
        return `${quotedColumn} >= ${paramRef}`;
      case "EQ":
      case "NEQ": {
        // Unless matchCase is set explicitly, only STRING columns compare
        // case-insensitively.
        const matchCase = filter.matchCase ?? columnType !== "STRING";
        const operator = filter.operation.op === "EQ" ? "=" : "!=";
        const caseHandling = caseHandlingBehavior({
          dialect,
          matchCase: Boolean(matchCase),
        });
        if (caseHandling.type === "collation") {
          // NOTE(review): the collate clause is emitted for every column type
          // on collation dialects, not just strings — confirm this is intended.
          return `${quotedColumn} collate ${caseHandling.collation} ${operator} ${paramRef}`;
        } else {
          // Other dialects lower-case both sides when case must be ignored.
          return `${maybeLowerWrap({
            argOrColumnName: quotedColumn,
            matchCase: matchCase,
            columnType,
          })} ${operator} ${maybeLowerWrap({
            argOrColumnName: paramRef,
            matchCase: matchCase,
            columnType,
          })}`;
        }
      }
      case "LTE":
        return `${quotedColumn} <= ${paramRef}`;
      case "LT":
        return `${quotedColumn} < ${paramRef}`;
      case "CONTAINS":
      case "NOT_CONTAINS": {
        return contains({
          quotedColumn,
          // If parameterized, we expect the formatted arg to be "?",
          // so pass that in. If not, pass in raw value to be parsed
          arg: isParameterized ? paramRef : filter.operation.arg,
          not: filter.operation.op === "NOT_CONTAINS",
          matchCase: Boolean(filter.matchCase ?? columnType !== "STRING"),
          dialect,
          allowInterpolation: filter.operation.allowInterpolation,
          isParameterized,
        });
      }
      case "DATE_BEFORE":
      case "DATE_EQUAL":
      case "DATE_NOT_EQUAL":
      case "DATE_AFTER":
      case "DATE_EQUAL_OR_AFTER":
      case "DATE_EQUAL_OR_BEFORE": {
        // Date comparisons cast the column (and, on some dialects, the value)
        // to a date before applying the mapped operator.
        const operator = DATE_OPERATOR_MAP[filter.operation.op];
        const dateColumn = dateCastColumn({
          columnType: filter.columnType,
          dialect,
          quotedColumn,
        });
        return `${dateColumn} ${operator} ${maybeDateCastValue(
          paramRef,
          dialect,
        )}`;
      }
      default:
        assertNever(filter.operation, filter.operation);
    }
  } else if (FilterCellListBinaryColumnPredicate.guard(filter.operation)) {
    // List predicates: the whole argument list is formatted as one reference
    // for IS_ONE_OF / NOT_ONE_OF. DATE_BETWEEN formats its bounds separately.
    const paramRef = formatArg({
      columnType,
      arg: filter.operation.arg,
      allowInterpolation: filter.operation.allowInterpolation,
      dialect,
    });
    const caseHandling = caseHandlingBehavior({
      dialect,
      matchCase: true, // list filters don't support case insensitivity
    });
    const collateClause =
      caseHandling.type === "collation"
        ? ` collate ${caseHandling.collation}`
        : "";
    switch (filter.operation.op) {
      case "IS_ONE_OF":
        return `${quotedColumn}${collateClause} in (${paramRef})`;
      case "NOT_ONE_OF":
        return `${quotedColumn}${collateClause} not in (${paramRef})`;
      case "DATE_BETWEEN": {
        // Both bounds are required; otherwise no condition is emitted.
        if (
          filter.operation.arg.length !== 2 ||
          filter.operation.arg[0] == null ||
          filter.operation.arg[1] == null
        ) {
          return "";
        }
        const dateColumn = dateCastColumn({
          columnType: filter.columnType,
          dialect,
          quotedColumn,
        });
        // NOTE(review): only the start bound decides whether the range is
        // treated as absolute dates; the end bound is not checked here.
        if (moment(filter.operation.arg[0]).isValid()) {
          const startParamRef = formatArg({
            columnType,
            arg: filter.operation.arg[0],
            allowInterpolation: filter.operation.allowInterpolation,
            dialect,
          });
          const endParamRef = formatArg({
            columnType,
            arg: filter.operation.arg[1],
            allowInterpolation: filter.operation.allowInterpolation,
            dialect,
          });
          // The fragment ends with a trailing space.
          return `${dateColumn} between ${maybeDateCastValue(startParamRef, dialect)} and ${maybeDateCastValue(endParamRef, dialect)} `;
        } else if (filter.operation.arg[0] && filter.operation.arg[1]) {
          // Otherwise both bounds are parsed as relative date offsets and
          // resolved to concrete dates before formatting.
          const parsedRelativeStartDate = tryParseRelativeDateOffsetValue(
            filter.operation.arg[0],
          );
          const parsedRelativeEndDate = tryParseRelativeDateOffsetValue(
            filter.operation.arg[1],
          );
          if (
            parsedRelativeStartDate == null ||
            parsedRelativeEndDate == null
          ) {
            return "";
          }
          // DATE columns use the date-only format; other column types keep
          // minute precision.
          const computedStartDate = computeRelativeDate(
            parsedRelativeStartDate,
          ).format(
            columnType === "DATE" ? DATE_FORMAT_NO_TIME : "YYYY-MM-DD HH:mm",
          );
          const computedEndDate = computeRelativeDate(
            parsedRelativeEndDate,
          ).format(
            columnType === "DATE" ? DATE_FORMAT_NO_TIME : "YYYY-MM-DD HH:mm",
          );
          const startParamRef = formatArg({
            columnType,
            arg: computedStartDate,
            allowInterpolation: filter.operation.allowInterpolation,
            dialect,
          });
          const endParamRef = formatArg({
            columnType,
            arg: computedEndDate,
            allowInterpolation: filter.operation.allowInterpolation,
            dialect,
          });
          return `${dateColumn} between ${maybeDateCastValue(startParamRef, dialect)} and ${maybeDateCastValue(endParamRef, dialect)} `;
        } else {
          return "";
        }
      }
      default:
        assertNever(filter.operation, filter.operation);
    }
  } else {
    assertNever(filter.operation, filter.operation);
  }
}

interface ContainsArgs {
  quotedColumn: string;
  arg: string;
  not: boolean;
  matchCase: boolean;
  dialect: PushdownSqlDialect;
  allowInterpolation: boolean | undefined;
  isParameterized?: boolean;
}

/**
 * Builds a `like` / `not like` (or `ilike`) substring predicate for a column.
 *
 * The search pattern is one of:
 * - a concat of `'%'`, the argument and `'%'` when parameterized,
 * - the evaluated Jinja expression when interpolation is allowed and the
 *   argument contains a Jinja expression,
 * - otherwise a Jinja string literal holding the wildcard-escaped argument
 *   surrounded by `%`.
 *
 * Case handling (ilike, lower() on both sides, or a collate clause) comes
 * from `caseHandlingBehavior`.
 */
function contains({
  allowInterpolation,
  arg,
  dialect,
  isParameterized,
  matchCase,
  not,
  quotedColumn,
}: ContainsArgs): string {
  const escapedArg = escapeSpecialChars(arg);

  // these dialects either don't support an escape clause, or default to \
  const implicitEscapeDialects: readonly PushdownSqlDialect[] = [
    "bigquery",
    "clickhouse",
    "databricks",
    "mysql",
    "cloudsql__mysql",
    "mariadb",
    "redshift",
  ];
  let escapeClause = "";
  if (arg !== escapedArg && !implicitEscapeDialects.includes(dialect)) {
    // snowflake makes you escape the backslash in the escape clause
    const escapeChar = dialect === "snowflake" ? "\\\\" : "\\";
    escapeClause = ` escape '${escapeChar}'`;
  }

  let pattern: string;
  if (isParameterized) {
    pattern = concatExpression(["'%'", arg, "'%'"], dialect);
  } else if (allowInterpolation && hasJinjaExpression(arg)) {
    pattern = evaluateJinja(arg, dialect, "pattern");
  } else {
    pattern = `{{ ${JSON.stringify(`%${escapedArg}%`)} }}`;
  }

  // Start from the plain case-sensitive form and adjust per strategy.
  let operator = "like";
  let lhs = quotedColumn;
  let rhs = pattern;
  let collateClause = "";
  const caseHandling = caseHandlingBehavior({ dialect, matchCase });
  switch (caseHandling.type) {
    case "ilike":
      operator = "ilike";
      break;
    case "lower":
      lhs = `lower(${quotedColumn})`;
      rhs = `lower(${pattern})`;
      break;
    case "collation":
      collateClause = ` collate ${caseHandling.collation}`;
      break;
    case "none":
      break;
  }

  const negation = not ? "not " : "";
  return `${lhs} ${negation}${operator} ${rhs}${escapeClause}${collateClause}`;
}

interface CaseHandlingArgs {
  dialect: PushdownSqlDialect;
  matchCase: boolean;
}
type CaseHandling =
  | { type: "collation"; collation: string }
  | { type: "ilike" }
  | { type: "lower" }
  | { type: "none" };

/**
 * Picks the case-sensitivity strategy for string comparisons in a dialect.
 *
 * - MySQL-family and SQL Server-family dialects always get an explicit
 *   collation, chosen by `matchCase`.
 * - The remaining dialects need nothing ("none") when matching case. When
 *   ignoring case they use either `lower()` on both sides or `ilike`,
 *   depending on the dialect.
 */
function caseHandlingBehavior({
  dialect,
  matchCase,
}: CaseHandlingArgs): CaseHandling {
  switch (dialect) {
    case "mysql":
    case "cloudsql__mysql":
    case "mariadb": {
      const collation = matchCase ? "utf8mb4_bin" : "utf8mb4_general_ci";
      return { type: "collation", collation };
    }
    case "sqlserver":
    case "cloudsql__sqlserver": {
      const collation = matchCase
        ? "Latin1_General_CS_AS"
        : "Latin1_General_CI_AI";
      return { type: "collation", collation };
    }
    case "bigquery":
    case "athena":
    case "trino":
    case "starburst":
    case "dremio": {
      if (matchCase) {
        return { type: "none" };
      }
      return { type: "lower" };
    }
    case "clickhouse":
    case "databricks":
    case "duckdb":
    case "motherduck":
    case "postgres":
    case "cloudsql":
    case "cloudsql__postgres":
    case "prestodb":
    case "redshift":
    case "snowflake":
    case "spark":
    case "alloydb":
    case "materialize": {
      if (matchCase) {
        return { type: "none" };
      }
      return { type: "ilike" };
    }
    default:
      assertNever(dialect, dialect);
  }
}

interface DateCastColumnArgs {
  quotedColumn: string;
  columnType: FilledDynamicValueTableColumnType;
  dialect: PushdownSqlDialect;
}

/**
 * Wraps a quoted column in the dialect-specific expression that yields its
 * date, converting to the `hex_timezone` Jinja value where the dialect
 * allows it. Columns already typed "DATE" are returned untouched.
 *
 * @throws DialectUnsupportedError for Dremio, which is not supported yet.
 */
function dateCastColumn({
  columnType,
  dialect,
  quotedColumn,
}: DateCastColumnArgs): string {
  if (columnType === "DATE") {
    return quotedColumn;
  }

  switch (dialect) {
    case "duckdb":
    case "motherduck":
      return `${quotedColumn}::date`;

    case "snowflake":
      return `to_date(convert_timezone({{ hex_timezone }}, ${quotedColumn}))`;

    case "bigquery":
      return `date(cast(${quotedColumn} as timestamp), {{ hex_timezone }})`;

    case "redshift":
    case "postgres":
    case "cloudsql":
    case "cloudsql__postgres":
    case "materialize":
    case "alloydb":
      return `cast(${quotedColumn}::timestamptz at time zone {{ hex_timezone }} as date)`;

    case "athena":
    case "prestodb":
    case "starburst":
    case "trino":
      // Athena/Trino don't support passing tz as prepared statement param
      return `cast(cast(${quotedColumn} as timestamp) at time zone '{{ hex_timezone | validatetz | sqlsafe }}' as date)`;

    case "spark":
    case "databricks":
      // we can't get databricks to respect hex_timezone without altering the session tz
      // and we can't alter the session tz because we can't issue multi-statement queries
      return `cast(${quotedColumn} as date)`;

    case "clickhouse":
      return `toDate(toDateTime(${quotedColumn}, {{ hex_timezone }}))`;

    case "mysql":
    case "cloudsql__mysql":
    case "mariadb":
      return `date(convert_tz(${quotedColumn}, 'UTC', {{ hex_timezone }}))`;

    case "sqlserver":
    case "cloudsql__sqlserver": {
      // SQL Server only understands Windows time zone names, whereas we use IANA names
      // so we have to ignore time zone
      const [year, month, day] = ["year", "month", "day"].map(
        (part) => `datepart(${part}, ${quotedColumn} at time zone 'UTC')`,
      );
      return `datefromparts(${year}, ${month}, ${day})`;
    }

    case "dremio":
      throw new DialectUnsupportedError(
        `TODO(dscott): figure out how to support Dremio`,
      );

    default:
      throw assertNever(dialect, dialect);
  }
}

/**
 * Casts a comparison value to a date on dialects that will not compare a date
 * with a string, and returns it unchanged elsewhere.
 *
 * The dialect groups mirror the ones used when casting the column side:
 * every Postgres-family dialect (including `cloudsql` and `alloydb`) gets
 * `::date`, and every Presto-family dialect (including `prestodb`) gets
 * `cast(... as date)`.
 *
 * @param value The already-formatted value reference.
 * @param dialect Target SQL dialect.
 * @returns The value expression, date-cast when the dialect requires it.
 */
function maybeDateCastValue(
  value: string,
  dialect: PushdownSqlDialect,
): string {
  switch (dialect) {
    case "redshift":
    case "materialize":
    case "postgres":
    case "cloudsql":
    case "cloudsql__postgres":
    case "alloydb": {
      return `${value}::date`;
    }
    case "athena":
    case "prestodb":
    case "starburst":
    case "trino": {
      return `cast(${value} as date)`;
    }
    default:
      // these other dialects allow comparison between dates and strings
      return value;
  }
}

/**
 * Wraps an expression in the dialect's string cast, using the non-throwing
 * variant (`try_cast` / `safe_cast`) where this function knows of one.
 *
 * SQL Server is handled separately from the other `try_cast` dialects. There,
 * a `varchar` without a length inside CAST/TRY_CAST means `varchar(30)` and
 * longer values are truncated, so the length is spelled out as `max`.
 *
 * @param expr The SQL expression to cast (typically a quoted column).
 * @param dialect Target SQL dialect.
 * @returns The cast expression.
 */
function tryStringCast({
  dialect,
  expr,
}: {
  expr: string;
  dialect: PushdownSqlDialect;
}): string {
  switch (dialect) {
    case "snowflake": {
      return `to_varchar(${expr})`;
    }
    case "spark":
    case "databricks": {
      return `string(${expr})`;
    }
    case "clickhouse": {
      return `toString(${expr})`;
    }
    case "bigquery": {
      return `safe_cast(${expr} as string)`;
    }
    case "duckdb":
    case "motherduck":
    case "prestodb":
    case "starburst":
    case "trino": {
      return `try_cast(${expr} as varchar)`;
    }
    case "sqlserver":
    case "cloudsql__sqlserver": {
      // An unsized varchar in a cast defaults to 30 characters on SQL Server.
      return `try_cast(${expr} as varchar(max))`;
    }
    case "mysql":
    case "cloudsql__mysql":
    case "mariadb": {
      return `cast(${expr} as char)`;
    }
    default: {
      return `cast(${expr} as varchar)`;
    }
  }
}

/**
 * Escapes a value for embedding inside a single-quoted Jinja string literal.
 *
 * Single quotes become the `\u0027` escape sequence. Backslashes are doubled
 * as well: the literal is read with backslash-escape processing, so an
 * unescaped backslash would either be reinterpreted (`\n`, `\t`, ...) or, at
 * the end of the value, swallow the closing quote.
 *
 * @param str Raw, user-supplied value.
 * @returns The value with backslashes and single quotes escaped.
 */
function escapeSingleQuotes(str: string): string {
  // Single pass, so the backslash introduced for a quote is never re-escaped.
  return str.replace(/[\\']/g, (ch) => (ch === "'" ? "\\u0027" : "\\\\"));
}

interface FilterArg {
  columnType: FilledDynamicValueTableColumnType;
  arg: string | string[];
  allowInterpolation: boolean | undefined;
  dialect: PushdownSqlDialect;
}

/**
 * Formats a filter argument as a Jinja expression (`{{ ... }}`).
 *
 * - Array arguments become a Jinja list piped through the `array` filter.
 * - A string containing a Jinja expression is evaluated as an exact match
 *   expression, but only when `allowInterpolation` is exactly `true`.
 * - Any other string is rendered as a single literal based on `columnType`.
 */
export function jinjaArg({
  allowInterpolation,
  arg: rawArg,
  columnType,
  dialect,
}: FilterArg): string {
  const quoted = (value: string): string => `'${escapeSingleQuotes(value)}'`;

  // Renders one raw value as a Jinja literal appropriate for the column type.
  const toLiteral = (value: string): string => {
    switch (columnType) {
      case FilledDynamicValueTableColumnType.BOOLEAN:
        return value.toLowerCase() === "true" ? "True" : "False";

      case FilledDynamicValueTableColumnType.NUMBER: {
        // the filter cell UI should only allow configuring numeric args, but
        // users can set arbitrary filter configs via GQL
        const looksNumeric = value.trim() !== "" && !isNaN(+value);
        // don't throw an error for non-numeric input because this function
        // gets called in the UI while configuring. Just treat like a string
        return looksNumeric ? value : quoted(value);
      }

      case FilledDynamicValueTableColumnType.DATE:
      case FilledDynamicValueTableColumnType.DATETIME:
      case FilledDynamicValueTableColumnType.DATETIMETZ: {
        if (moment(value).isValid() || value === "") {
          return quoted(value);
        }
        // Not an absolute date: try a relative offset, else an empty literal.
        const relativeOffset = tryParseRelativeDateOffsetValue(value);
        if (relativeOffset == null) {
          return quoted("");
        }
        const outputFormat =
          columnType === "DATE" ? DATE_FORMAT_NO_TIME : "YYYY-MM-DD HH:mm";
        return quoted(computeRelativeDate(relativeOffset).format(outputFormat));
      }

      case FilledDynamicValueTableColumnType.TIME:
      case FilledDynamicValueTableColumnType.STRING:
      case FilledDynamicValueTableColumnType.UNKNOWN:
        return quoted(value);

      default:
        assertNever(columnType, columnType);
    }
  };

  if (Array.isArray(rawArg)) {
    const items = rawArg.map((item) => toLiteral(item)).join(", ");
    return `{{ ${`[${items}]`} | array }}`;
  }

  if (allowInterpolation === true && hasJinjaExpression(rawArg)) {
    return evaluateJinja(rawArg, dialect, "exact");
  }

  return `{{ ${toLiteral(rawArg)} }}`;
}

/**
 * Wraps a column or argument reference in `lower(...)` when a
 * case-insensitive comparison is wanted. Only STRING and UNKNOWN column
 * types are ever wrapped; every other type is returned as-is.
 */
function maybeLowerWrap({
  argOrColumnName,
  columnType,
  matchCase,
}: {
  argOrColumnName: string;
  matchCase?: boolean;
  columnType: FilledDynamicValueTableColumnType;
}): string {
  switch (columnType) {
    case FilledDynamicValueTableColumnType.STRING:
    case FilledDynamicValueTableColumnType.UNKNOWN:
      return matchCase ? argOrColumnName : `lower(${argOrColumnName})`;

    case FilledDynamicValueTableColumnType.BOOLEAN:
    case FilledDynamicValueTableColumnType.DATE:
    case FilledDynamicValueTableColumnType.TIME:
    case FilledDynamicValueTableColumnType.DATETIME:
    case FilledDynamicValueTableColumnType.DATETIMETZ:
    case FilledDynamicValueTableColumnType.NUMBER:
      return argOrColumnName;

    default:
      assertNever(columnType, columnType);
  }
}

/**
 * Reports whether a filter, or at least one filter inside a group, has
 * everything needed to be rendered.
 *
 * A filter needs a non-empty column, an operation and a column type. Binary
 * predicates also need a non-empty argument, and list predicates a non-empty
 * argument list. Operations matching none of the known predicate kinds are
 * never considered configured.
 */
export function filterFullyConfigured(
  filterOrFilterGroup: Filter | FilterGroup,
): boolean {
  if (FilterGroup.guard(filterOrFilterGroup)) {
    return filterOrFilterGroup.filters.some(filterFullyConfigured);
  }

  const { column, columnType, operation } = filterOrFilterGroup;
  // Requirements shared by every predicate kind.
  const hasBasics =
    column != null && column !== "" && operation != null && columnType != null;

  if (FilterCellUnaryColumnPredicate.guard(operation)) {
    return hasBasics;
  }
  if (FilterCellBinaryColumnPredicate.guard(operation)) {
    return hasBasics && operation.arg != null && operation.arg !== "";
  }
  if (FilterCellListBinaryColumnPredicate.guard(operation)) {
    return hasBasics && operation.arg != null && operation.arg.length !== 0;
  }
  return false;
}

/**
 * Blank starting filter: an `EQ` comparison with an empty column and an empty
 * argument on a column of UNKNOWN type. Because the column is empty it is not
 * considered fully configured by `filterFullyConfigured`.
 */
export const BASE_FILTER: Filter = {
  column: "" as FilterCellColumnId,
  operation: {
    op: "EQ",
    arg: "",
  },
  columnType: FilledDynamicValueTableColumnType.UNKNOWN,
};

/** Starting filter group: a single `BASE_FILTER` combined with `AND`. */
export const BASE_FILTER_GROUP: FilterGroup = {
  operation: CompoundColumnPredicateOp.AND,
  filters: [BASE_FILTER],
};

/**
 * Returns a fresh global regex that captures `{{ ... }}` Jinja expressions.
 * A global regex carries `lastIndex` state between calls, so every caller
 * gets its own instance instead of sharing one.
 */
export function getJinjaRegex(): RegExp {
  return new RegExp("({{[^}]+}})", "g");
}

/**
 * Tells whether the string contains at least one `{{ ... }}` expression.
 * Only simple expressions without a `}` inside the braces are recognized.
 */
export function hasJinjaExpression(str: string): boolean {
  const firstMatch = getJinjaRegex().exec(str);
  return firstMatch !== null;
}

/**
 * Backslash-escapes the characters that are special inside a SQL `like`
 * pattern: the backslash itself, `%` and `_`.
 */
function escapeSpecialChars(str: string): string {
  // One pass over the string: prefix each special character with a backslash.
  return str.replace(/[\\%_]/g, (special) => `\\${special}`);
}

/**
 * Turns a string that mixes literal text with `{{ ... }}` Jinja expressions
 * into a SQL expression that concatenates the pieces.
 *
 * Jinja expressions are passed through untouched. Literal text is also
 * emitted as a Jinja string parameter (`{{ "..." }}`) rather than being
 * spliced into the SQL as a quoted literal. The text comes from
 * user-editable filter configuration, so a `'` in it must never be able to
 * terminate a SQL string.
 *
 * `like` wildcards in literal text are escaped only for "pattern" mode.
 * "exact" mode is used for plain comparisons, where escaping would change
 * the value being compared.
 *
 * Exported for testing.
 *
 * @param arg Raw argument containing at least one Jinja expression.
 * @param dialect Target SQL dialect (affects how the concat is written).
 * @param type "pattern" surrounds the result with `%` wildcards for `like`;
 * "exact" concatenates the pieces as-is.
 * @returns The SQL expression, or the single piece when there is only one.
 */
export function evaluateJinja(
  arg: string,
  dialect: PushdownSqlDialect,
  type: "pattern" | "exact",
): string {
  // A piece is an expression only if it is exactly one regex match. Leftover
  // text between matches can never match as a whole, so this cannot misfire
  // on literal text that merely starts with "{{" and ends with "}}".
  const wholeJinjaExpression = /^{{[^}]+}}$/;

  const parts = arg
    .split(getJinjaRegex())
    .filter((part) => part.length > 0)
    .map((part) => {
      if (wholeJinjaExpression.test(part)) {
        // use unescaped value, as it gets replaced by parameters
        return part;
      }
      const literal = type === "pattern" ? escapeSpecialChars(part) : part;
      // JSON.stringify yields a double-quoted, fully escaped string literal.
      return `{{ ${JSON.stringify(literal)} }}`;
    });

  const concatArgs = type === "pattern" ? ["'%'", ...parts, "'%'"] : parts;
  return concatExpression(concatArgs, dialect);
}

/**
 * Builds the SQL that concatenates the given expressions.
 *
 * A single part is returned unchanged. Redshift gets nested two-argument
 * `concat` calls; every other dialect gets one variadic `concat`.
 */
function concatExpression(
  parts: string[],
  dialect: PushdownSqlDialect,
): string {
  const head = parts[0] ?? "";
  if (parts.length === 1) {
    return head;
  }

  if (dialect === "redshift") {
    // Fold left: concat(concat(a, b), c) ...
    return parts
      .slice(1)
      .reduce((nested, part) => `concat(${nested}, ${part})`, head);
  }

  return `concat(${parts.join(", ")})`;
}
