#pragma once

#include "location.hh"
#include "utils.h"
#include <map>
#include <string>
#include <vector>
#include <unordered_set>

#include "types.h"

namespace bpftrace {
namespace ast {

class Visitor;

/// Root of the AST class hierarchy.
///
/// Every node remembers the source location it was built from and exposes
/// `accept`, the double-dispatch entry point used by Visitor subclasses.
class Node {
public:
  Node() : loc(location()) {}
  Node(location loc) : loc(loc) {}
  virtual ~Node() = default;

  /// Call the Visitor::visit overload matching this node's concrete type.
  virtual void accept(Visitor &v) = 0;

  location loc;
};

class Map;
class Variable;

/// Base class for nodes that produce a value.
///
/// Besides the inferred `type`, an expression carries back-pointers that
/// describe the context it appears in. These are populated by the owning node:
/// Map sets `key_for_map` on its key expressions, AssignMapStatement sets
/// `map`, and AssignVarStatement sets `var`.
class Expression : public Node {
public:
  Expression() : Node() {}
  Expression(location loc) : Node(loc) {}

  SizedType type;
  Map *key_for_map = nullptr;  // map that uses this expression as a key
  Map *map = nullptr;          // only set when this expression is assigned to a map
  Variable *var = nullptr;     // set when this expression is assigned to a variable
  bool is_literal = false;
  bool is_variable = false;
  bool is_map = false;

  /// Shared set of type names that need to be resolved. Node constructors
  /// insert into it (e.g. Cast adds its target type, Builtin `curtask` adds
  /// "task_struct"); presumably consumed later to import those type
  /// definitions — TODO confirm against the consumer.
  static std::unordered_set<std::string>& getResolve();
};
using ExpressionList = std::vector<Expression *>;

class Integer : public Expression {
public:
  explicit Integer(long n) : n(n) { is_literal = true; }
  explicit Integer(long n, location loc) : Expression(loc), n(n) { is_literal = true; }
  long n;

  void accept(Visitor &v) override;
};

/// Reference to a positional parameter of kind `ptype` with index `n`.
/// `is_in_str` starts out false; other passes may set it.
class PositionalParameter : public Expression {
public:
  explicit PositionalParameter(PositionalParameterType ptype, long n)
    : PositionalParameter(ptype, n, location())
  {
  }
  explicit PositionalParameter(PositionalParameterType ptype, long n, location loc)
    : Expression(loc), ptype(ptype), n(n)
  {
  }

  PositionalParameterType ptype;
  long n;
  bool is_in_str = false;

  void accept(Visitor &v) override;
};

class String : public Expression {
public:
  explicit String(std::string str) : str(str) { is_literal = true; }
  explicit String(std::string str, location loc) : Expression(loc), str(str) { is_literal = true; }
  std::string str;

  void accept(Visitor &v) override;
};

class StackMode : public Expression {
public:
  explicit StackMode(std::string mode) : mode(mode) {}
  explicit StackMode(std::string mode, location loc) : Expression(loc), mode(mode) {}
  std::string mode;

  void accept(Visitor &v) override;
};

class Identifier : public Expression {
public:
  explicit Identifier(std::string ident) : ident(ident) {}
  explicit Identifier(std::string ident, location loc) : Expression(loc), ident(ident) {}
  std::string ident;

  void accept(Visitor &v) override;
};

/// Reference to a builtin variable.
///
/// The name is passed through is_deprecated() before being stored in
/// `ident` (presumably mapping deprecated spellings to their replacement —
/// see utils.h). Referencing `curtask` registers "task_struct" in
/// Expression::getResolve().
class Builtin : public Expression {
public:
  explicit Builtin(std::string ident) : ident(is_deprecated(ident)) {
    // Check the stored (normalised) name, not the raw constructor argument.
    resolve_curtask(this->ident);
  }
  explicit Builtin(std::string ident, location loc) : Expression(loc), ident(is_deprecated(ident)) {
    resolve_curtask(this->ident);
  }
  std::string ident;
  // Initialised so that a read before it is assigned is well-defined.
  int probe_id = 0;

  void accept(Visitor &v) override;

private:
  // `curtask` needs the task_struct type, so queue it for resolution.
  static void resolve_curtask(const std::string &name) {
    if (name == "curtask")
      getResolve().insert("task_struct");
  }
};

class Call : public Expression {
public:
  explicit Call(std::string &func) : func(is_deprecated(func)), vargs(nullptr) { }
  explicit Call(std::string &func, location loc) : Expression(loc), func(is_deprecated(func)), vargs(nullptr) { }
  Call(std::string &func, ExpressionList *vargs) : func(is_deprecated(func)), vargs(vargs) { }
  Call(std::string &func, ExpressionList *vargs, location loc) : Expression(loc), func(is_deprecated(func)), vargs(vargs) { }
  std::string func;
  ExpressionList *vargs;

  void accept(Visitor &v) override;
};

class Map : public Expression {
public:
  explicit Map(std::string &ident, location loc) : Expression(loc), ident(ident), vargs(nullptr) { is_map = true; }
  Map(std::string &ident, ExpressionList *vargs) : ident(ident), vargs(vargs) { is_map = true; }
  Map(std::string &ident, ExpressionList *vargs, location loc) : Expression(loc), ident(ident), vargs(vargs)
  {
    is_map = true;
    for (auto expr : *vargs)
    {
      expr->key_for_map = this;
    }
  }
  std::string ident;
  ExpressionList *vargs;
  bool skip_key_validation = false;

  void accept(Visitor &v) override;
};

class Variable : public Expression {
public:
  explicit Variable(std::string &ident) : ident(ident) { is_variable = true; }
  explicit Variable(std::string &ident, location loc) : Expression(loc), ident(ident) { is_variable = true; }
  std::string ident;

  void accept(Visitor &v) override;
};

/// Binary operation `left <op> right`; `op` is an integer operator code
/// (see opstr(Binop&) for its textual form).
class Binop : public Expression {
public:
  Binop(Expression *left, int op, Expression *right, location loc)
    : Expression(loc),
      left(left),
      right(right),
      op(op)
  {
  }

  Expression *left;
  Expression *right;
  int op;

  void accept(Visitor &v) override;
};

/// Unary operation applied to `expr`.
///
/// `op` is an integer operator code (see opstr(Unop&) for its textual form)
/// and `is_post_op` marks the postfix form of the operator.
class Unop : public Expression {
public:
  // `loc` deliberately has no default here. With one, `Unop(op, expr)`
  // matched this constructor and the next one equally well, and the call
  // failed to compile as ambiguous. Two-argument calls now resolve to the
  // constructor below (is_post_op = false, default location).
  Unop(int op, Expression *expr, location loc)
    : Expression(loc), expr(expr), op(op), is_post_op(false) { }
  Unop(int op, Expression *expr, bool is_post_op = false, location loc = location())
    : Expression(loc), expr(expr), op(op), is_post_op(is_post_op) { }
  Expression *expr;
  int op;
  bool is_post_op;

  void accept(Visitor &v) override;
};

/// Access of member `field` on the value produced by `expr`.
class FieldAccess : public Expression {
public:
  FieldAccess(Expression *expr, const std::string &field)
    : FieldAccess(expr, field, location())
  {
  }
  FieldAccess(Expression *expr, const std::string &field, location loc)
    : Expression(loc), expr(expr), field(field)
  {
  }

  Expression *expr;
  std::string field;

  void accept(Visitor &v) override;
};

/// Indexing `expr[indexpr]`.
class ArrayAccess : public Expression {
public:
  ArrayAccess(Expression *expr, Expression *indexpr)
    : ArrayAccess(expr, indexpr, location())
  {
  }
  ArrayAccess(Expression *expr, Expression *indexpr, location loc)
    : Expression(loc), expr(expr), indexpr(indexpr)
  {
  }

  Expression *expr;
  Expression *indexpr;

  void accept(Visitor &v) override;
};

class Cast : public Expression {
public:
  Cast(const std::string &type, bool is_pointer, Expression *expr)
    : cast_type(type), is_pointer(is_pointer), expr(expr) {
    getResolve().insert(type);
  }
  Cast(const std::string &type, bool is_pointer, Expression *expr, location loc)
    : Expression(loc), cast_type(type), is_pointer(is_pointer), expr(expr) {
    getResolve().insert(type);
  }
  std::string cast_type;
  bool is_pointer;
  Expression *expr;

  void accept(Visitor &v) override;
};

/// Base class for statements in a probe body.
class Statement : public Node {
public:
  Statement() = default;
  Statement(location loc) : Node(loc) {}
};
using StatementList = std::vector<Statement *>;

/// Statement that consists of a single expression.
class ExprStatement : public Statement {
public:
  explicit ExprStatement(Expression *expr) : ExprStatement(expr, location()) {}
  explicit ExprStatement(Expression *expr, location loc)
    : Statement(loc), expr(expr)
  {
  }

  Expression *expr;

  void accept(Visitor &v) override;
};

/// Assignment of `expr` into `map`. The right-hand side's `map`
/// back-pointer is set to the target map on construction.
class AssignMapStatement : public Statement {
public:
  AssignMapStatement(Map *map, Expression *expr, location loc = location())
    : Statement(loc), map(map), expr(expr)
  {
    expr->map = map;
  }

  Map *map;
  Expression *expr;

  void accept(Visitor &v) override;
};

/// Assignment of `expr` into variable `var`. The right-hand side's `var`
/// back-pointer is set to the target variable on construction.
class AssignVarStatement : public Statement {
public:
  AssignVarStatement(Variable *var, Expression *expr)
    : AssignVarStatement(var, expr, location())
  {
  }
  AssignVarStatement(Variable *var, Expression *expr, location loc)
    : Statement(loc), var(var), expr(expr)
  {
    expr->var = var;
  }

  Variable *var;
  Expression *expr;

  void accept(Visitor &v) override;
};

/// Conditional statement. `else_stmts` is null when no else branch was
/// supplied (the two-argument constructor).
class If : public Statement {
public:
  If(Expression *cond, StatementList *stmts) : If(cond, stmts, nullptr) { }
  If(Expression *cond, StatementList *stmts, StatementList *else_stmts)
    : cond(cond), stmts(stmts), else_stmts(else_stmts)
  {
  }

  Expression *cond;
  StatementList *stmts = nullptr;
  StatementList *else_stmts = nullptr;

  void accept(Visitor &v) override;
};

/// Unroll statement wrapping `stmts`. `var` presumably holds the number of
/// repetitions — TODO confirm against the passes that consume it.
class Unroll : public Statement {
public:
  Unroll(long int var, StatementList *stmts)
    : var(var),
      stmts(stmts)
  {
  }

  long int var = 0;
  StatementList *stmts;

  void accept(Visitor &v) override;
};

/// Probe predicate wrapping a single condition expression.
class Predicate : public Node {
public:
  explicit Predicate(Expression *expr) : Predicate(expr, location()) { }
  explicit Predicate(Expression *expr, location loc) : Node(loc), expr(expr) { }

  Expression *expr;

  void accept(Visitor &v) override;
};

/// Conditional expression `cond ? left : right`.
class Ternary : public Expression {
public:
  Ternary(Expression *cond, Expression *left, Expression *right)
    : Ternary(cond, left, right, location())
  {
  }
  Ternary(Expression *cond, Expression *left, Expression *right, location loc)
    : Expression(loc), cond(cond), left(left), right(right)
  {
  }

  Expression *cond;
  Expression *left;
  Expression *right;

  void accept(Visitor &v) override;
};

/// One attach point of a probe.
///
/// Each constructor overload fills a different subset of the fields below.
/// `provider` is always stored as returned by probetypeName(), presumably the
/// canonical provider name when an alias is given — see types.h.
class AttachPoint : public Node {
public:
  explicit AttachPoint(const std::string &provider, location loc=location())
    : Node(loc), provider(probetypeName(provider)) { }
  AttachPoint(const std::string &provider,
              const std::string &func,
              location loc=location())
    : Node(loc), provider(probetypeName(provider)), func(func), need_expansion(true) { }
  AttachPoint(const std::string &provider,
              const std::string &target,
              const std::string &func,
              bool need_expansion,
              location loc=location())
    : Node(loc), provider(probetypeName(provider)), target(target), func(func), need_expansion(need_expansion) { }
  AttachPoint(const std::string &provider,
              const std::string &target,
              const std::string &ns,
              const std::string &func,
              bool need_expansion,
              location loc=location())
    : Node(loc), provider(probetypeName(provider)), target(target), ns(ns), func(func), need_expansion(need_expansion) { }
  // Numeric attach point: `val` is an address for uprobes, otherwise a
  // frequency.
  AttachPoint(const std::string &provider,
              const std::string &target,
              uint64_t val,
              location loc=location())
    : Node(loc), provider(probetypeName(provider)), target(target), need_expansion(true)
  {
    // Compare the normalised member, not the raw argument (which shadows it).
    // Otherwise a spelling that probetypeName() maps to "uprobe" would
    // wrongly take the frequency branch.
    if (this->provider == "uprobe")
      address = val;
    else
      freq = static_cast<int>(val);  // freq is an int; narrowing made explicit
  }
  AttachPoint(const std::string &provider,
              const std::string &target,
              uint64_t addr,
              uint64_t len,
              const std::string &mode,
              location loc=location())
    : Node(loc), provider(probetypeName(provider)), target(target), addr(addr), len(len), mode(mode) { }
  AttachPoint(const std::string &provider,
              const std::string &target,
              const std::string &func,
              uint64_t offset,
              location loc=location())
    : Node(loc), provider(probetypeName(provider)), target(target), func(func), need_expansion(true), func_offset(offset) { }

  std::string provider;
  std::string target;
  std::string ns;
  std::string func;
  usdt_probe_entry usdt; // resolved USDT entry, used to support arguments with wildcard matches
  int freq = 0;
  uint64_t addr = 0;
  uint64_t len = 0;
  std::string mode;
  bool need_expansion = false;
  uint64_t address = 0;
  uint64_t func_offset = 0;

  void accept(Visitor &v) override;
  std::string name(const std::string &attach_point) const;

  int index(std::string name);
  void set_index(std::string name, int index);
private:
  std::map<std::string, int> index_;
};
using AttachPointList = std::vector<AttachPoint *>;

/// A probe: its attach points, a predicate and a statement body.
/// (`pred` is presumably null when the probe has no predicate — confirm
/// with the parser.)
class Probe : public Node {
public:
  Probe(AttachPointList *attach_points, Predicate *pred, StatementList *stmts)
    : attach_points(attach_points),
      pred(pred),
      stmts(stmts)
  {
  }

  AttachPointList *attach_points;
  Predicate *pred;
  StatementList *stmts;

  void accept(Visitor &v) override;
  std::string name() const;

  // must build a BPF program per wildcard match
  bool need_expansion = false;
  // must import struct for tracepoints
  bool need_tp_args_structs = false;

  int index();
  void set_index(int index);

private:
  int index_ = 0;
};
using ProbeList = std::vector<Probe *>;

/// Root node of a parsed script: raw C definitions text plus the list of
/// probes.
class Program : public Node {
public:
  Program(const std::string &c_definitions, ProbeList *probes)
    : c_definitions(c_definitions),
      probes(probes)
  {
  }

  std::string c_definitions;
  ProbeList *probes;

  void accept(Visitor &v) override;
};

/// Abstract visitor over the AST. Each node's `accept` calls the overload
/// below that matches its concrete type; subclasses must handle every node.
class Visitor {
public:
  virtual ~Visitor() = default;

  // Expressions
  virtual void visit(Integer &integer) = 0;
  virtual void visit(PositionalParameter &param) = 0;
  virtual void visit(String &string) = 0;
  virtual void visit(Builtin &builtin) = 0;
  virtual void visit(Identifier &identifier) = 0;
  virtual void visit(StackMode &mode) = 0;
  virtual void visit(Call &call) = 0;
  virtual void visit(Map &map) = 0;
  virtual void visit(Variable &var) = 0;
  virtual void visit(Binop &binop) = 0;
  virtual void visit(Unop &unop) = 0;
  virtual void visit(Ternary &ternary) = 0;
  virtual void visit(FieldAccess &acc) = 0;
  virtual void visit(ArrayAccess &arr) = 0;
  virtual void visit(Cast &cast) = 0;

  // Statements
  virtual void visit(ExprStatement &expr) = 0;
  virtual void visit(AssignMapStatement &assignment) = 0;
  virtual void visit(AssignVarStatement &assignment) = 0;
  virtual void visit(If &if_block) = 0;
  virtual void visit(Unroll &unroll) = 0;

  // Probe structure
  virtual void visit(Predicate &pred) = 0;
  virtual void visit(AttachPoint &ap) = 0;
  virtual void visit(Probe &probe) = 0;
  virtual void visit(Program &program) = 0;
};

// Textual form of a Binop/Unop operator code (the `op` field). Presumably the
// operator's source spelling; defined out of line — confirm against the
// implementation.
std::string opstr(Binop &binop);
std::string opstr(Unop &unop);

} // namespace ast
} // namespace bpftrace
